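"""Archives Twitch chat for one or more channels.

An Archiver connects to Twitch IRC (via girc), buckets the messages it receives
into fixed-length batches keyed by (channel, batch time), and writes each batch
to disk as a JSON-lines file once it can no longer change. Because overlapping
clients (eg. across RECONNECTs, or multiple hosts) may write duplicate batches,
merge_all() periodically merges any batch files that share the same batch time.
Emote images and (optionally) media linked in messages are downloaded in the
background as they are seen.
"""
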
import base64
import hashlib
import json
import logging
import os
import random
import re
import string
import signal
import socket
import time
from calendar import timegm
from collections import defaultdict
from datetime import datetime
from itertools import count

import gevent.event
import gevent.queue

from common import atomic_write, listdir
from common.chat import BATCH_INTERVAL, format_batch, get_batch_files, merge_messages
from common.media import download_media, FailedResponse, WrongContent, Rejected

from girc import Client
from monotonic import monotonic
import prometheus_client as prom
import requests

from .keyed_group import KeyedGroup


# These are known to arrive up to MAX_DELAY after their actual time
DELAYED_COMMANDS = [
	"JOIN",
	"PART",
]
# This isn't documented, but we've observed up to 30sec of delay, so we pad a little extra
# and hope it's good enough.
MAX_DELAY = 45

COMMANDS = DELAYED_COMMANDS + [
	"PRIVMSG",
	"CLEARCHAT",
	"CLEARMSG",
	"HOSTTARGET",
	"NOTICE",
	"ROOMSTATE",
	"USERNOTICE",
	"USERSTATE",
]

# Assume we're never more than this amount of time behind the server time
# Worst case if too low: multiple output files for same batch that need merging later.
# Should be greater than MAX_DELAY.
MAX_SERVER_LAG = 60

# When guessing when a non-timestamped event occurred, pad the possible range
# by up to this amount before and after our best guess
ESTIMATED_TIME_PADDING = 5
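
# Worked example of the batching maths used in Archiver.run() below, assuming a
# hypothetical BATCH_INTERVAL of 60s (the real value is imported from common.chat):
# a message with server timestamp 12:34:05 falls into the batch starting at 12:34:00
# (int(timestamp / BATCH_INTERVAL) * BATCH_INTERVAL), and that batch is only closed and
# written once local time reaches 12:34:00 + BATCH_INTERVAL + MAX_SERVER_LAG = 12:36:00.
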
messages_received = prom.Counter(
	"messages_received",
	"Number of chat messages received by the client. 'client' tag is per client instance.",
	["channel", "client", "command"],
)

messages_ignored = prom.Counter(
	"messages_ignored",
	"Number of chat messages that were received but ignored for some reason (see reason label)",
	["client", "command", "reason"],
)

messages_written = prom.Counter(
	"messages_written",
	"Number of chat messages received and then written out to disk in a batch.",
	["channel", "client", "command"],
)

batch_messages = prom.Histogram(
	"batch_messages",
	"Number of messages in batches written to disk",
	["channel", "client"],
	buckets=[0, 1, 4, 16, 64, 256, 1024],
)

# based on DB2021, an average PRIVMSG is about 600 bytes.
# so since batch_messages goes up to 1024, batch_bytes should go up to ~ 600KB.
# let's just call it 1MB.
batch_bytes = prom.Histogram(
	"batch_bytes",
	"Size in bytes of batches written to disk",
	["channel", "client"],
	buckets=[0, 256, 1024, 4096, 16384, 65536, 262144, 1048576]
)

open_batches = prom.Gauge(
	"open_batches",
	"Number of batches that have at least one pending message not yet written to disk",
	["channel", "client"],
)

server_lag = prom.Gauge(
	"server_lag",
	"Estimated time difference between server-side timestamps and local time, based on latest message",
	["channel", "client"],
)

merge_pass_duration = prom.Histogram(
	"merge_pass_duration",
	"How long it took to run through all batches and merge any duplicates",
)
merge_pass_merges = prom.Histogram(
	"merge_pass_merges",
	"How many merges (times for which more than one batch existed) were done in a single merge pass",
	buckets=[0, 1, 10, 100, 1000, 10000],
)


class Archiver(object):
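	"""Single Twitch IRC connection that archives chat for a set of channels.

	Messages are received via a girc Client, queued, grouped into batches keyed by
	(channel, batch time), and written to disk under BASE_DIR/CHANNEL/chat/ once the
	batch can no longer change. The name is used to record which client received each
	message (the 'receivers' field) and to label metrics.
	"""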
	def __init__(self, name, base_dir, channels, nick, oauth_token, download_media):
		self.logger = logging.getLogger(type(self).__name__).getChild(name)
		self.name = name
		self.messages = gevent.queue.Queue()
		self.channels = channels
		self.base_dir = base_dir
		self.download_media = download_media

		self.stopping = gevent.event.Event()
		self.got_reconnect = gevent.event.Event()
		self.client = Client(
			hostname='irc.chat.twitch.tv',
			port=6697,
			ssl=True,
			nick=nick,
			password=oauth_token,
			twitch=True,
			stop_handler=lambda c: self.stopping.set(),
		)
		for channel in self.channels:
			self.client.channel('#{}'.format(channel)).join()

	def channel_path(self, channel):
		return os.path.join(self.base_dir, channel, "chat")

	def write_batch(self, channel, batch_time, messages):
		# wrapper around general write_batch() function
		write_batch(
			self.channel_path(channel), batch_time, messages,
			size_histogram=batch_bytes.labels(channel=channel, client=self.name),
		)
		batch_messages.labels(channel=channel, client=self.name).observe(len(messages))
		# incrementing a prom counter can be stupidly expensive, collect up per-command values
		# so we can do them in one go
		by_command = defaultdict(lambda: 0)
		for message in messages:
			by_command[message["command"]] += 1
		for command, count in by_command.items():
			messages_written.labels(channel=channel, client=self.name, command=command).inc(count)

	def run(self):
		@self.client.handler(sync=True)
		def handle_message(client, message):
			self.messages.put(message)

		# Twitch sends a RECONNECT shortly before terminating the connection from the server side.
		# This gives us time to start up a new instance of the archiver while keeping this one
		# running, so that we can be sure we don't miss anything. This will cause duplicate batches,
		# but those will get merged later.
		@self.client.handler(command='RECONNECT')
		def handle_reconnect(client, message):
			self.got_reconnect.set()

		self.client.start()

		last_server_time = None
		last_timestamped_message = None
		# {(channel, batch time): [messages]}
		batches = {}
		for channel in self.channels:
			# Bind channel as a default argument so each gauge counts its own channel's
			# batches, rather than every closure seeing the loop's final value.
			open_batches.labels(channel=channel, client=self.name).set_function(
				lambda channel=channel: len([1 for c, t in batches if c == channel])
			)

		# Tracks if we've seen the initial ROOMSTATE for each channel we've joined.
		# Everything up to and including this message is per-connection:
		# - a JOIN for us joining the room (even if we were already there on another connection)
		# - a USERSTATE for our user
		# - a ROOMSTATE for the room
		# We ignore all messages before the initial ROOMSTATE.
		initialized_channels = set()

		while not (self.stopping.is_set() and self.messages.empty()):
			# wait until we either have a message, are stopping, or a batch can be closed
			if batches:
				oldest_batch_time = min(batch_time for channel, batch_time in batches.keys())
				next_batch_close = oldest_batch_time + BATCH_INTERVAL + MAX_SERVER_LAG
				self.logger.debug("Next batch close at {} (batch times: {})".format(next_batch_close, list(batches.keys())))
				timeout = max(0, next_batch_close - time.time())
			else:
				timeout = None
			self.logger.debug("Waiting up to {} for message or stop".format(timeout))
			gevent.wait([gevent.spawn(self.messages.peek), self.stopping], count=1, timeout=timeout)

			# close any closable batches
			now = time.time()
			for (channel, batch_time), messages in list(batches.items()):
				if now >= batch_time + BATCH_INTERVAL + MAX_SERVER_LAG:
					del batches[channel, batch_time]
					self.write_batch(channel, batch_time, messages)

			# consume a message if any
			try:
				message = self.messages.get(block=False)
			except gevent.queue.Empty:
				continue

			self.logger.debug("Got message: {}".format(message))

			if message.command not in COMMANDS:
				self.logger.info("Skipping non-whitelisted command: {}".format(message.command))
				messages_ignored.labels(client=self.name, command=message.command, reason="non-whitelisted").inc()
				continue

			# For all message types we capture, the channel name is always the first param.
			if not message.params:
				self.logger.error(f"Skipping malformed message with no params - cannot determine channel: {message}")
				messages_ignored.labels(client=self.name, command=message.command, reason="no-channel").inc()
				continue

			channel = message.params[0].lstrip("#")

			if channel not in self.channels:
				self.logger.error(f"Skipping unexpected message for unrequested channel {channel}")
				messages_ignored.labels(client=self.name, command=message.command, reason="bad-channel").inc()
				continue

			if channel not in initialized_channels:
				self.logger.debug(f"Skipping {message.command} message on non-initialized channel {channel}")
				if message.command == "ROOMSTATE":
					initialized_channels.add(channel)
					self.logger.info(f"Channel {channel} is ready")
				messages_ignored.labels(client=self.name, command=message.command, reason="non-initialized-channel").inc()
				continue

			data = {
				attr: getattr(message, attr)
				for attr in ('command', 'params', 'sender', 'user', 'host', 'tags')
			}
			data['receivers'] = {self.name: message.received_at}
			self.logger.debug("Got message data: {}".format(data))
			messages_received.labels(channel=channel, client=self.name, command=message.command).inc()

			if data['tags'] and data['tags'].get('emotes', '') != '':
				emote_specs = data['tags']['emotes'].split('/')
				emote_ids = [emote_spec.split(':')[0] for emote_spec in emote_specs]
				ensure_emotes(self.base_dir, emote_ids)

			if self.download_media and data['command'] == "PRIVMSG" and len(data["params"]) == 2:
				ensure_image_links(self.base_dir, data["params"][1])

			if data['tags'] and 'tmi-sent-ts' in data['tags']:
				# explicit server time is available
				timestamp = int(data['tags']['tmi-sent-ts']) / 1000. # original is int ms
				last_timestamped_message = message
				last_server_time = timestamp
				server_lag.labels(channel=channel, client=self.name).set(time.time() - timestamp)
				time_range = 0
				self.logger.debug("Message has exact timestamp: {}".format(timestamp))
				# check for any non-timestamped messages which we now know must have been
				# before this message. We need to check this batch and the previous.
				batch_time = int(timestamp / BATCH_INTERVAL) * BATCH_INTERVAL
				for batch in (batch_time, batch_time - BATCH_INTERVAL):
					for msg in batches.get((channel, batch), []):
						time_between = timestamp - msg['time']
						if 0 < time_between < msg['time_range']:
							self.logger.debug("Updating previous message {m[command]}@{m[time]} range {m[time_range]} -> {new}".format(
								m=msg, new=time_between,
							))
							msg['time_range'] = time_between
			elif last_server_time is not None:
				# estimate current server time based on time since last timestamped message
				est_server_time = last_server_time + time.time() - last_timestamped_message.received_at
				# pad either side of the estimated server time, use this as a baseline
				timestamp = est_server_time - ESTIMATED_TIME_PADDING
				time_range = 2 * ESTIMATED_TIME_PADDING
				# if previously timestamped message falls within this range, we know this message
				# came after it
				timestamp = max(timestamp, last_server_time)
			else:
				# we have no idea what the server time is, so we guess as 2x the normal padding
				# starting from local time.
				timestamp = time.time() - 2 * ESTIMATED_TIME_PADDING
				time_range = 3 * ESTIMATED_TIME_PADDING

			if data['command'] in DELAYED_COMMANDS:
				# might have happened MAX_DELAY sooner than otherwise indicated.
				timestamp -= MAX_DELAY
				time_range += MAX_DELAY
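
			# Worked example of the estimation path above (hypothetical numbers): if the last
			# timestamped message had server time 1000.0 and was received locally at 2000.0,
			# and an untimestamped JOIN arrives at local time 2010.0, then est_server_time is
			# 1010.0, giving time = 1005.0 and time_range = 10; since JOIN is a delayed command
			# this widens to time = 960.0 with time_range = 55.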
self.logger.debug("Message time determined as {} + up to {}".format(timestamp, time_range))
|
|
data['time'] = timestamp
|
|
data['time_range'] = time_range
|
|
batch_time = int(timestamp / BATCH_INTERVAL) * BATCH_INTERVAL
|
|
batches.setdefault((channel, batch_time), []).append(data)
|
|
|
|
# Close any remaining batches
|
|
for (channel, batch_time), messages in batches.items():
|
|
self.write_batch(channel, batch_time, messages)
|
|
|
|
self.client.wait_for_stop() # re-raise any errors
|
|
self.logger.info("Client stopped")
|
|
|
|
def stop(self):
|
|
self.client.stop()
|
|
|
|
|
|
_EMOTES_RUNNING = KeyedGroup()
def ensure_emotes(base_dir, emote_ids):
	"""Tries to download given emote from twitch if it doesn't already exist.
	This happens in the background and errors are ignored.
	"""
	def _ensure_emote(emote_id, theme, scale):
		url = "https://static-cdn.jtvnw.net/emoticons/v2/{}/default/{}/{}".format(emote_id, theme, scale)
		path = os.path.join(base_dir, "emotes", emote_id, "{}-{}".format(theme, scale))
		if os.path.exists(path):
			logging.debug("Emote {} already exists".format(path))
			return
		logging.info("Fetching emote from {}".format(url))
		try:
			response = requests.get(url)
		except Exception:
			logging.warning("Exception while fetching emote from {}".format(url), exc_info=True)
			return
		if not response.ok:
			logging.warning("Error {} while fetching emote from {}: {}".format(response.status_code, url, response.text))
			return
		atomic_write(path, response.content)
		logging.info("Saved emote {}".format(path))

	for emote_id in emote_ids:
		for theme in ('light', 'dark'):
			for scale in ('1.0', '2.0', '3.0'):
				# to prevent downloading the same emote twice because the first download isn't finished yet,
				# use a KeyedGroup.
				key = base_dir, emote_id, theme, scale
				_EMOTES_RUNNING.spawn(key, _ensure_emote, emote_id, theme, scale)


def wait_for_ensure_emotes():
	_EMOTES_RUNNING.wait()


URL_REGEX = re.compile(r"""
	# Previous char is not a letter. This prevents eg. "foohttp://example.com"
	# Also disallows / as the previous character, otherwise "file:///foo.bar/baz"
	# can match on the "foo.bar/baz" part.
	(?<! [\w/] )
	# optional scheme, which must be http or https (we don't want other schemes)
	(?P<scheme> https?:// )?
	# Hostname, which must contain a dot. Single-part hostnames like "localhost" are valid
	# but we don't want to match them, and this avoids cases like "yes/no" matching.
	# We enforce that the TLD is not fully numeric. No TLDs currently look like this
	# (though it does end up forbidding raw IPv4 addresses), and a common false-positive
	# is "1.5/10" or similar.
	( [a-z0-9-]+ \. )+ [a-z][a-z0-9-]+
	# Optional port
	( : [0-9]+ )?
	# Optional path. We assume a path character can be anything that's not completely disallowed
	# but don't try to parse it further into query, fragment etc.
	# We also include all unicode characters considered "letters" since it's likely someone might
	# put a ö or something in a path and copy-paste it from their browser URL bar which renders it
	# like that even though it's encoded when actually sent as a URL.
	# Restricting this to letters prevents things like non-breaking spaces causing problems.
	# For the same reason we also allow {} and [] which seem to show up often in paths.
	(?P<path> / [\w!#$%&'()*+,./:;=?@_~{}\[\]-]* )?
	""", re.VERBOSE | re.IGNORECASE)


_IMAGE_LINKS_RUNNING = KeyedGroup()
def ensure_image_links(base_dir, text):
	"""Find any image or video links in the text and download them if we don't have them already.
	This happens in the background and errors are ignored."""
	media_dir = os.path.join(base_dir, "media")

	def get_url(url):
		try:
			try:
				download_media(url, media_dir)
			except FailedResponse:
				# We got a 404 or similar.
				# Attempt to remove any stray punctuation from the url and try again.
				# We only try this once.
				if url.endswith("..."):
					url = url[:-3]
				elif not url[-1].isalnum():
					url = url[:-1]
				else:
					# No punctuation found, let the original result stand
					raise
				download_media(url, media_dir)
		except WrongContent as e:
			logging.info(f"Ignoring non-media link {url}: {e}")
		except Rejected as e:
			logging.warning(f"Rejected dangerous link {url}: {e}")
		except Exception:
			logging.warning(f"Unable to fetch link {url}", exc_info=True)

	for match in URL_REGEX.finditer(text):
		# Don't match on bare hostnames with no scheme AND no path. ie.
		# http://example.com SHOULD match
		# example.com/foo SHOULD match
		# example.com SHOULD NOT match
		# Otherwise we get a false positive every time someone says "fart.wav" or similar.
		if match.group("scheme") is None and match.group("path") is None:
			continue
		url = match.group(0)
		key = (media_dir, url)
		_IMAGE_LINKS_RUNNING.spawn(key, get_url, url)


def write_batch(path, batch_time, messages, size_histogram=None):
	"""Batches are named PATH/YYYY-MM-DDTHH/MM:SS-HASH.json"""
	output = (format_batch(messages) + '\n').encode('utf-8')
	if size_histogram is not None:
		size_histogram.observe(len(output))
	hash = base64.b64encode(hashlib.sha256(output).digest(), b"-_").decode().rstrip("=")
	hour = datetime.utcfromtimestamp(batch_time).strftime("%Y-%m-%dT%H")
	time = datetime.utcfromtimestamp(batch_time).strftime("%M:%S")
	filename = os.path.join(hour, "{}-{}.json".format(time, hash))
	filepath = os.path.join(path, filename)
	if os.path.exists(filepath):
		logging.debug("Not writing batch {} - already exists.".format(filename))
	else:
		atomic_write(filepath, output)
		logging.info("Wrote batch {}".format(filepath))
	return filepath


def merge_all(path, interval=None, stopping=None):
	"""Repeatedly scans the batch directory for batch files with the same batch time, and merges them.
	By default, returns once it finds no duplicate files.
	If interval is given, re-scans after that number of seconds.
	If a gevent.event.Event() is passed in as stopping arg, returns when that event is set.
	"""
	if stopping is None:
		# nothing will ever set this, but it's easier to not special-case it everywhere
		stopping = gevent.event.Event()
	while not stopping.is_set():
		start = monotonic()
		merges_done = 0
		# loop until no changes
		while True:
			logging.debug("Scanning for merges")
			by_time = {}
			for hour in listdir(path):
				for name in listdir(os.path.join(path, hour)):
					if not name.endswith(".json"):
						continue
					min_sec = name.split("-")[0]
					timestamp = "{}:{}".format(hour, min_sec)
					by_time[timestamp] = by_time.get(timestamp, 0) + 1
			if not any(count > 1 for timestamp, count in by_time.items()):
				logging.info("All batches are merged")
				break
			for timestamp, count in by_time.items():
				if count > 1:
					logging.info("Merging {} batches at time {}".format(count, timestamp))
					batch_time = timegm(time.strptime(timestamp, "%Y-%m-%dT%H:%M:%S"))
					merge_batch_files(path, batch_time)
					merges_done += 1
		duration = monotonic() - start
		merge_pass_duration.observe(duration)
		merge_pass_merges.observe(merges_done)
		if interval is None:
			# one-shot
			break
		remaining = interval - duration
		if remaining > 0:
			logging.debug("Waiting {}s for next merge scan".format(remaining))
			stopping.wait(remaining)


def merge_batch_files(path, batch_time):
	"""For the given batch time, merges all the following messages:
	- From batch files at that time
	- From batch files for the previous batch time
	- From batch files for the following batch time
	and writes up to 3 batch files (one for each time) to replace them.
	"""
	# A note on race conditions:
	# Suppose two processes attempt to merge the same batch at the same time.
	# The critical section consists of:
	# 1. Reading the old batch files
	# 2. Writing the new batch files
	# 3. Deleting the old batch files
	# Crucially, we don't delete any data until we've written a replacement,
	# and we don't delete any data that we didn't just incorporate into a new file.
	# This might cause doubling up of data, eg. version A -> version B but also
	# version A -> version C, but the end result will be that both B and C exist
	# and will then be merged later.

	messages = []
	batch_files = [
		batch_file
		for time in [batch_time, batch_time - BATCH_INTERVAL, batch_time + BATCH_INTERVAL]
		for batch_file in get_batch_files(path, time)
	]
	for batch_file in batch_files:
		with open(batch_file) as f:
			batch = f.read()
		batch = [json.loads(line) for line in batch.strip().split("\n")]
		messages = merge_messages(messages, batch)

	by_time = {}
	for message in messages:
		batch_time = int(message['time'] / BATCH_INTERVAL) * BATCH_INTERVAL
		by_time.setdefault(batch_time, []).append(message)

	written = []
	for batch_time, batch in by_time.items():
		written.append(write_batch(path, batch_time, batch))

	for batch_file in batch_files:
		# don't delete something we just (re-)wrote
		if batch_file not in written:
			os.remove(batch_file)


def main(nick, oauth_token_path, *channels, base_dir='/mnt', name=None, merge_interval=60, metrics_port=8008, download_media=False):
	with open(oauth_token_path) as f:
		oauth_token = f.read()
	# To ensure uniqueness even if multiple instances are running on the same host,
	# also include a random slug
	if name is None:
		name = socket.gethostname()
	slug = ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(5))
	name = "{}.{}".format(name, slug)
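	# e.g. (hypothetical) name "myhost" becomes "myhost.aB3xQ", and each Archiver instance
	# below is then named "myhost.aB3xQ.0", "myhost.aB3xQ.1", ... as connections are replaced.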

	stopping = gevent.event.Event()
	gevent.signal_handler(signal.SIGTERM, stopping.set)

	mergers = [
		gevent.spawn(merge_all,
			os.path.join(base_dir, channel, "chat"),
			interval=merge_interval,
			stopping=stopping
		) for channel in channels
	]

	prom.start_http_server(metrics_port)

	logging.info("Starting")
	for index in count():
		# To ensure uniqueness between clients, include a client number
		archiver = Archiver("{}.{}".format(name, index), base_dir, channels, nick, oauth_token, download_media)
		archive_worker = gevent.spawn(archiver.run)
		workers = mergers + [archive_worker]
		# wait for either graceful exit, error, or for a signal from the archiver that a reconnect was requested
		gevent.wait([stopping, archiver.got_reconnect] + workers, count=1)
		if stopping.is_set():
			archiver.stop()
			for worker in workers:
				worker.get()
			break
		# if got reconnect, discard the old archiver (we don't care even if it fails after this)
		# and make a new one
		if archiver.got_reconnect.is_set():
			logging.info("Got RECONNECT, creating new client while waiting for old one to finish")
			continue
		# the only remaining case is that something failed. stop everything and re-raise.
		logging.warning("Stopping due to worker dying")
		stopping.set()
		archiver.stop()
		for worker in workers:
			worker.join()
		# at least one of these should raise
		for worker in workers:
			worker.get()
		assert False, "Worker unexpectedly exited successfully"
	logging.info("Gracefully stopped")
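

# Hypothetical example invocation (the real deployment wires this up via its own entry
# point, so treat this as a sketch rather than the supported interface):
# 	main("my_bot_nick", "/path/to/oauth_token_file", "somechannel", "otherchannel",
# 		base_dir="/mnt", merge_interval=60, metrics_port=8008, download_media=False)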