chat-archiver: File merging and other fixes

pull/300/head
Mike Lang 3 years ago committed by Mike Lang
parent 0756539b85
commit d32cbbb7e1

@ -17,8 +17,9 @@ from common import ensure_directory
from girc import Client
# How long each batch is
BATCH_INTERVAL = 60
class Archiver(object):
# These are known to arrive up to 10s after their actual time
DELAYED_COMMANDS = [
"JOIN",
@ -36,9 +37,6 @@ class Archiver(object):
"USERSTATE",
]
# How long each batch is
BATCH_INTERVAL = 60
# Assume we're never more than this amount of time behind the server time
# Worst case if too low: multiple output files for same batch that need merging later
MAX_SERVER_LAG = 30
@ -47,6 +45,8 @@ class Archiver(object):
# by up to this amount before and after our best guess
ESTIMATED_TIME_PADDING = 5
class Archiver(object):
def __init__(self, name, base_dir, channel, nick, oauth_token):
self.logger = logging.getLogger(type(self).__name__).getChild(channel)
self.name = name
@ -85,7 +85,7 @@ class Archiver(object):
while not self.stopping.is_set():
# wait until we either have a message, are stopping, or a batch can be closed
if batches:
next_batch_close = min(batches.keys()) + self.BATCH_INTERVAL + self.MAX_SERVER_LAG
next_batch_close = min(batches.keys()) + BATCH_INTERVAL + MAX_SERVER_LAG
self.logger.debug("Next batch close at {} (batch times: {})".format(next_batch_close, batches.keys()))
timeout = max(0, next_batch_close - time.time())
else:
@ -96,9 +96,9 @@ class Archiver(object):
# close any closable batches
now = time.time()
for batch_time, messages in list(batches.items()):
if now >= batch_time + self.BATCH_INTERVAL + self.MAX_SERVER_LAG:
if now >= batch_time + BATCH_INTERVAL + MAX_SERVER_LAG:
del batches[batch_time]
self.write_batch(batch_time, messages)
write_batch(self.path, batch_time, messages)
# consume a message if any
try:
@ -106,7 +106,7 @@ class Archiver(object):
except gevent.queue.Empty:
continue
if message.command not in self.COMMANDS:
if message.command not in COMMANDS:
self.logger.info("Skipping non-whitelisted command: {}".format(message.command))
continue
@ -127,8 +127,8 @@ class Archiver(object):
self.logger.debug("Message has exact timestamp: {}".format(timestamp))
# check for any non-timestamped messages which we now know must have been
# before this message. We need to check this batch and the previous.
batch_time = int(timestamp / self.BATCH_INTERVAL) * self.BATCH_INTERVAL
for batch in (batch_time, batch_time - self.BATCH_INTERVAL):
batch_time = int(timestamp / BATCH_INTERVAL) * BATCH_INTERVAL
for batch in (batch_time, batch_time - BATCH_INTERVAL):
for msg in batches.get(batch, []):
time_between = timestamp - msg['time']
if 0 < time_between < msg['time_range']:
@ -140,18 +140,18 @@ class Archiver(object):
# estimate current server time based on time since last timestamped message
est_server_time = last_server_time + time.time() - last_timestamped_message.received_at
# pad either side of the estimated server time, use this as a baseline
timestamp = est_server_time - self.ESTIMATED_TIME_PADDING
time_range = 2 * self.ESTIMATED_TIME_PADDING
timestamp = est_server_time - ESTIMATED_TIME_PADDING
time_range = 2 * ESTIMATED_TIME_PADDING
# if previously timestamped message falls within this range, we know this message
# came after it
timestamp = max(timestamp, last_server_time)
else:
# we have no idea what the server time is, so we guess as 2x the normal padding
# starting from local time.
timestamp = time.time() - 2 * self.ESTIMATED_TIME_PADDING
time_range = 3 * self.ESTIMATED_TIME_PADDING
timestamp = time.time() - 2 * ESTIMATED_TIME_PADDING
time_range = 3 * ESTIMATED_TIME_PADDING
if data['command'] in self.DELAYED_COMMANDS:
if data['command'] in DELAYED_COMMANDS:
# might have happened 10s sooner than otherwise indicated.
timestamp -= 10
time_range += 10
@ -159,16 +159,34 @@ class Archiver(object):
self.logger.debug("Message time determined as {} + up to {}".format(timestamp, time_range))
data['time'] = timestamp
data['time_range'] = time_range
batch_time = int(timestamp / self.BATCH_INTERVAL) * self.BATCH_INTERVAL
batch_time = int(timestamp / BATCH_INTERVAL) * BATCH_INTERVAL
batches.setdefault(batch_time, []).append(data)
# Close any remaining batches
for batch_time, messages in batches.items():
self.write_batch(batch_time, messages)
write_batch(self.path, batch_time, messages)
self.client.wait_for_stop() # re-raise any errors
def write_batch(self, batch_time, messages):
def stop(self):
self.client.stop()
def write_batch(path, batch_time, messages):
output = (format_batch(messages) + '\n').encode('utf-8')
hash = base64.b64encode(hashlib.sha256(output).digest(), b"-_").decode().rstrip("=")
time = datetime.utcfromtimestamp(batch_time).strftime("%Y-%m-%dT%H:%M:%S")
filename = "{}-{}.json".format(time, hash)
filepath = os.path.join(path, filename)
temppath = "{}.{}.temp".format(filepath, uuid4())
ensure_directory(filepath)
with open(temppath, 'wb') as f:
f.write(output)
os.rename(temppath, filepath)
logging.info("Wrote batch {}".format(filepath))
def format_batch(messages):
# We need to take some care to have a consistent ordering and format here.
# We use a "canonicalised JSON" format, which is really just whatever the python encoder does,
# with compact separators.
@ -179,20 +197,152 @@ class Archiver(object):
# We sort by timestamp, then timestamp range, then if all else fails, lexiographically
# on the encoded representation.
messages.sort(key=lambda item: (item[0]['time'], item[0]['time_range'], item[1]))
output = ("\n".join(line for message, line in messages) + "\n").encode("utf-8")
hash = base64.b64encode(hashlib.sha256(output).digest(), b"-_").decode().rstrip("=")
time = datetime.utcfromtimestamp(batch_time).strftime("%Y-%m-%dT%H:%M:%S")
filename = "{}-{}.json".format(time, hash)
filepath = os.path.join(self.path, filename)
temppath = "{}.{}.temp".format(filepath, uuid4())
ensure_directory(filepath)
with open(temppath, 'wb') as f:
f.write(output)
os.rename(temppath, filepath)
self.logger.info("Wrote batch {}".format(filepath))
return "\n".join(line for message, line in messages)
def stop(self):
self.client.stop()
def get_batch_files(path, batch_time):
"""Returns list of batch filepaths for a given batch time"""
time = datetime.utcfromtimestamp(batch_time).strftime("%Y-%m-%dT%H:%M:%S-")
return [
os.path.join(path, name)
for name in os.listdir(path)
if name.startswith(time) and name.endswith(".json")
]
def merge_batch_files(path, batch_time):
"""For the given batch time, merges all the following messages:
- From batch files at that time
- From batch files for the previous batch time
- From batch files for the following batch time
and writes up to 3 batch files (one for each time) to replace them.
"""
# A note on race conditions:
# Suppose two processes attempt to merge the same batch at the same time.
# The critical section consists of:
# 1. Reading the old batch files
# 2. Writing the new batch files
# 3. Deleting the old batch files
# Crucially, we don't delete any data until we've written a replacement,
# and we don't delete any data that we didn't just incorporate into a new file.
# This might cause doubling up of data, eg. version A -> version B but also
# version A -> version C, but the end result will be that both B and C exist
# and will then be merged later.
messages = []
batch_files = [
batch_file
for time in [batch_time, batch_time - BATCH_INTERVAL, batch_time + BATCH_INTERVAL]
for batch_file in get_batch_files(path, time)
]
for batch_file in batch_files:
with open(batch_file) as f:
batch = f.read()
batch = [json.loads(line) for line in batch.strip().split("\n")]
messages = merge_messages(messages, batch)
# sorting by time is needed for group_by(), we'll sort properly on save.
messages.sort(key=lambda message: message['time'])
for batch_time, batch in itertools.group_by(messages, key=
lambda message: int(message['time'] / BATCH_INTERVAL) * BATCH_INTERVAL
):
write_batch(path, batch_time, batch)
for batch_file in batch_files:
os.remove(batch_file)
def merge_messages(left, right):
"""Merges two lists of messages into one merged list.
This operation should be a CRDT, ie. all the following hold:
- associative: merge(merge(A, B), C) == merge(A, merge(B, C))
- commutitive: merge(A, B) == merge(B, A)
- reflexive: merge(A, A) == A
This means that no matter what order information from different sources
is incorporated (or if sources are repeated), the results should be the same.
"""
# An optimization - if either size is empty, return the other side without processing.
if not left:
return right
if not right:
return left
# Calculates intersection of time range of both messages, or None if they don't overlap
def overlap(a, b):
range_start = max(a['time'], b['time'])
range_end = min(a['time'] + a['time_range'], b['time'] + b['time_range'])
if range_end < range_start:
return None
return range_start, range_end - range_start
# Returns merged message if two messages are compatible with being the same message,
# or else None.
def merge_message(a, b):
o = overlap(a, b)
if o and all(
a.get(k) == b.get(k)
for k in set(a.keys()) | set(b.keys())
if k not in ("receivers", "time", "time_range")
):
receivers = a["receivers"] | b["receivers"]
# Error checkdng - make sure no receiver timestamps are being overwritten.
# This would indicate we're merging two messages recieved at different times
# by the same recipient.
for k in receivers.keys():
for old in (a, b):
if k in old and old[k] != recievers[k]:
raise ValueError(f"Merge would merge two messages with different recipient timestamps: {a}, {b}")
return a | {
"time": o[0],
"time_range": o[1],
"receivers": receivers,
}
return None
# Match things with identical ids first, and collect unmatched into left and right lists
by_id = {}
unmatched = [], []
for messages, u in zip((left, right), unmatched):
for message in messages:
id = message.get('tags', {}).get('id')
if id:
by_id.setdefault(id, []).append(message)
else:
u.append(message)
result = []
for id, messages in by_id.items():
if len(messages) == 1:
logging.debug(f"Message with id {id} has no match")
result.append(messages[0])
else:
merged = merge_message(*messages)
if merged is None:
raise ValueError(f"Got two non-matching messages with id {id}: {messages[0]}, {messages[1]}")
logging.debug(f"Merged messages with id {id}")
result.append(merged)
# For time-range messages, pair off each one in left with first match in right,
# and pass through anything with no matches.
left_unmatched, right_unmatched = unmatched
for message in left_unmatched:
for other in right_unmatched:
merged = merge_message(message, other)
if merged:
logging.debug(
"Matched {m[command]} message {a[time]}+{a[time_range]} & {b[time]}+{b[time_range]} -> {m[time]}+{m[time_range]}"
.format(a=message, b=other, m=merged)
)
right_unmatched.remove(other)
result.append(merged)
break
else:
logging.debug("No match found for {m[command]} at {m[time]}+{m[time_range]}".format(m=message))
result.append(message)
for message in right_unmatched:
logging.debug("No match found for {m[command]} at {m[time]}+{m[time_range]}".format(m=message))
result.append(message)
return result
def main(channel, nick, oauth_token_path, base_dir='/mnt'):

@ -0,0 +1,20 @@
import argh
import logging
import json
from .main import merge_messages, format_batch
def main(*paths, log='INFO'):
"""Merge all listed batch files and output result to stdout"""
logging.basicConfig(level=log)
messages = []
for path in paths:
with open(path) as f:
batch = f.read()
batch = [json.loads(line) for line in batch.strip().split("\n")]
messages = merge_messages(messages, batch)
print(format_batch(messages))
if __name__ == '__main__':
argh.dispatch_command(main)
Loading…
Cancel
Save