diff --git a/chat_archiver/chat_archiver/main.py b/chat_archiver/chat_archiver/main.py index db9c3a5..d112894 100644 --- a/chat_archiver/chat_archiver/main.py +++ b/chat_archiver/chat_archiver/main.py @@ -17,36 +17,36 @@ from common import ensure_directory from girc import Client +# How long each batch is +BATCH_INTERVAL = 60 -class Archiver(object): - # These are known to arrive up to 10s after their actual time - DELAYED_COMMANDS = [ - "JOIN", - "PART", - ] +# These are known to arrive up to 10s after their actual time +DELAYED_COMMANDS = [ + "JOIN", + "PART", +] - COMMANDS = DELAYED_COMMANDS + [ - "PRIVMSG", - "CLEARCHAT", - "CLEARMSG", - "HOSTTARGET", - "NOTICE", - "ROOMSTATE", - "USERNOTICE", - "USERSTATE", - ] +COMMANDS = DELAYED_COMMANDS + [ + "PRIVMSG", + "CLEARCHAT", + "CLEARMSG", + "HOSTTARGET", + "NOTICE", + "ROOMSTATE", + "USERNOTICE", + "USERSTATE", +] - # How long each batch is - BATCH_INTERVAL = 60 +# Assume we're never more than this amount of time behind the server time +# Worst case if too low: multiple output files for same batch that need merging later +MAX_SERVER_LAG = 30 - # Assume we're never more than this amount of time behind the server time - # Worst case if too low: multiple output files for same batch that need merging later - MAX_SERVER_LAG = 30 +# When guessing when a non-timestamped event occurred, pad the possible range +# by up to this amount before and after our best guess +ESTIMATED_TIME_PADDING = 5 - # When guessing when a non-timestamped event occurred, pad the possible range - # by up to this amount before and after our best guess - ESTIMATED_TIME_PADDING = 5 +class Archiver(object): def __init__(self, name, base_dir, channel, nick, oauth_token): self.logger = logging.getLogger(type(self).__name__).getChild(channel) self.name = name @@ -85,7 +85,7 @@ class Archiver(object): while not self.stopping.is_set(): # wait until we either have a message, are stopping, or a batch can be closed if batches: - next_batch_close = min(batches.keys()) + self.BATCH_INTERVAL + self.MAX_SERVER_LAG + next_batch_close = min(batches.keys()) + BATCH_INTERVAL + MAX_SERVER_LAG self.logger.debug("Next batch close at {} (batch times: {})".format(next_batch_close, batches.keys())) timeout = max(0, next_batch_close - time.time()) else: @@ -96,9 +96,9 @@ class Archiver(object): # close any closable batches now = time.time() for batch_time, messages in list(batches.items()): - if now >= batch_time + self.BATCH_INTERVAL + self.MAX_SERVER_LAG: + if now >= batch_time + BATCH_INTERVAL + MAX_SERVER_LAG: del batches[batch_time] - self.write_batch(batch_time, messages) + write_batch(self.path, batch_time, messages) # consume a message if any try: @@ -106,7 +106,7 @@ class Archiver(object): except gevent.queue.Empty: continue - if message.command not in self.COMMANDS: + if message.command not in COMMANDS: self.logger.info("Skipping non-whitelisted command: {}".format(message.command)) continue @@ -127,8 +127,8 @@ class Archiver(object): self.logger.debug("Message has exact timestamp: {}".format(timestamp)) # check for any non-timestamped messages which we now know must have been # before this message. We need to check this batch and the previous. - batch_time = int(timestamp / self.BATCH_INTERVAL) * self.BATCH_INTERVAL - for batch in (batch_time, batch_time - self.BATCH_INTERVAL): + batch_time = int(timestamp / BATCH_INTERVAL) * BATCH_INTERVAL + for batch in (batch_time, batch_time - BATCH_INTERVAL): for msg in batches.get(batch, []): time_between = timestamp - msg['time'] if 0 < time_between < msg['time_range']: @@ -140,18 +140,18 @@ class Archiver(object): # estimate current server time based on time since last timestamped message est_server_time = last_server_time + time.time() - last_timestamped_message.received_at # pad either side of the estimated server time, use this as a baseline - timestamp = est_server_time - self.ESTIMATED_TIME_PADDING - time_range = 2 * self.ESTIMATED_TIME_PADDING + timestamp = est_server_time - ESTIMATED_TIME_PADDING + time_range = 2 * ESTIMATED_TIME_PADDING # if previously timestamped message falls within this range, we know this message # came after it timestamp = max(timestamp, last_server_time) else: # we have no idea what the server time is, so we guess as 2x the normal padding # starting from local time. - timestamp = time.time() - 2 * self.ESTIMATED_TIME_PADDING - time_range = 3 * self.ESTIMATED_TIME_PADDING + timestamp = time.time() - 2 * ESTIMATED_TIME_PADDING + time_range = 3 * ESTIMATED_TIME_PADDING - if data['command'] in self.DELAYED_COMMANDS: + if data['command'] in DELAYED_COMMANDS: # might have happened 10s sooner than otherwise indicated. timestamp -= 10 time_range += 10 @@ -159,42 +159,192 @@ class Archiver(object): self.logger.debug("Message time determined as {} + up to {}".format(timestamp, time_range)) data['time'] = timestamp data['time_range'] = time_range - batch_time = int(timestamp / self.BATCH_INTERVAL) * self.BATCH_INTERVAL + batch_time = int(timestamp / BATCH_INTERVAL) * BATCH_INTERVAL batches.setdefault(batch_time, []).append(data) # Close any remaining batches for batch_time, messages in batches.items(): - self.write_batch(batch_time, messages) + write_batch(self.path, batch_time, messages) self.client.wait_for_stop() # re-raise any errors - def write_batch(self, batch_time, messages): - # We need to take some care to have a consistent ordering and format here. - # We use a "canonicalised JSON" format, which is really just whatever the python encoder does, - # with compact separators. - messages = [ - (message, json.dumps(message, separators=(',', ':'))) - for message in messages - ] - # We sort by timestamp, then timestamp range, then if all else fails, lexiographically - # on the encoded representation. - messages.sort(key=lambda item: (item[0]['time'], item[0]['time_range'], item[1])) - output = ("\n".join(line for message, line in messages) + "\n").encode("utf-8") - hash = base64.b64encode(hashlib.sha256(output).digest(), b"-_").decode().rstrip("=") - time = datetime.utcfromtimestamp(batch_time).strftime("%Y-%m-%dT%H:%M:%S") - filename = "{}-{}.json".format(time, hash) - filepath = os.path.join(self.path, filename) - temppath = "{}.{}.temp".format(filepath, uuid4()) - ensure_directory(filepath) - with open(temppath, 'wb') as f: - f.write(output) - os.rename(temppath, filepath) - self.logger.info("Wrote batch {}".format(filepath)) - def stop(self): self.client.stop() +def write_batch(path, batch_time, messages): + output = (format_batch(messages) + '\n').encode('utf-8') + hash = base64.b64encode(hashlib.sha256(output).digest(), b"-_").decode().rstrip("=") + time = datetime.utcfromtimestamp(batch_time).strftime("%Y-%m-%dT%H:%M:%S") + filename = "{}-{}.json".format(time, hash) + filepath = os.path.join(path, filename) + temppath = "{}.{}.temp".format(filepath, uuid4()) + ensure_directory(filepath) + with open(temppath, 'wb') as f: + f.write(output) + os.rename(temppath, filepath) + logging.info("Wrote batch {}".format(filepath)) + + +def format_batch(messages): + # We need to take some care to have a consistent ordering and format here. + # We use a "canonicalised JSON" format, which is really just whatever the python encoder does, + # with compact separators. + messages = [ + (message, json.dumps(message, separators=(',', ':'))) + for message in messages + ] + # We sort by timestamp, then timestamp range, then if all else fails, lexiographically + # on the encoded representation. + messages.sort(key=lambda item: (item[0]['time'], item[0]['time_range'], item[1])) + return "\n".join(line for message, line in messages) + + +def get_batch_files(path, batch_time): + """Returns list of batch filepaths for a given batch time""" + time = datetime.utcfromtimestamp(batch_time).strftime("%Y-%m-%dT%H:%M:%S-") + return [ + os.path.join(path, name) + for name in os.listdir(path) + if name.startswith(time) and name.endswith(".json") + ] + + +def merge_batch_files(path, batch_time): + """For the given batch time, merges all the following messages: + - From batch files at that time + - From batch files for the previous batch time + - From batch files for the following batch time + and writes up to 3 batch files (one for each time) to replace them. + """ + # A note on race conditions: + # Suppose two processes attempt to merge the same batch at the same time. + # The critical section consists of: + # 1. Reading the old batch files + # 2. Writing the new batch files + # 3. Deleting the old batch files + # Crucially, we don't delete any data until we've written a replacement, + # and we don't delete any data that we didn't just incorporate into a new file. + # This might cause doubling up of data, eg. version A -> version B but also + # version A -> version C, but the end result will be that both B and C exist + # and will then be merged later. + + messages = [] + batch_files = [ + batch_file + for time in [batch_time, batch_time - BATCH_INTERVAL, batch_time + BATCH_INTERVAL] + for batch_file in get_batch_files(path, time) + ] + for batch_file in batch_files: + with open(batch_file) as f: + batch = f.read() + batch = [json.loads(line) for line in batch.strip().split("\n")] + messages = merge_messages(messages, batch) + + # sorting by time is needed for group_by(), we'll sort properly on save. + messages.sort(key=lambda message: message['time']) + for batch_time, batch in itertools.group_by(messages, key= + lambda message: int(message['time'] / BATCH_INTERVAL) * BATCH_INTERVAL + ): + write_batch(path, batch_time, batch) + + for batch_file in batch_files: + os.remove(batch_file) + +def merge_messages(left, right): + """Merges two lists of messages into one merged list. + This operation should be a CRDT, ie. all the following hold: + - associative: merge(merge(A, B), C) == merge(A, merge(B, C)) + - commutitive: merge(A, B) == merge(B, A) + - reflexive: merge(A, A) == A + This means that no matter what order information from different sources + is incorporated (or if sources are repeated), the results should be the same. + """ + # An optimization - if either size is empty, return the other side without processing. + if not left: + return right + if not right: + return left + + # Calculates intersection of time range of both messages, or None if they don't overlap + def overlap(a, b): + range_start = max(a['time'], b['time']) + range_end = min(a['time'] + a['time_range'], b['time'] + b['time_range']) + if range_end < range_start: + return None + return range_start, range_end - range_start + + # Returns merged message if two messages are compatible with being the same message, + # or else None. + def merge_message(a, b): + o = overlap(a, b) + if o and all( + a.get(k) == b.get(k) + for k in set(a.keys()) | set(b.keys()) + if k not in ("receivers", "time", "time_range") + ): + receivers = a["receivers"] | b["receivers"] + # Error checkdng - make sure no receiver timestamps are being overwritten. + # This would indicate we're merging two messages recieved at different times + # by the same recipient. + for k in receivers.keys(): + for old in (a, b): + if k in old and old[k] != recievers[k]: + raise ValueError(f"Merge would merge two messages with different recipient timestamps: {a}, {b}") + return a | { + "time": o[0], + "time_range": o[1], + "receivers": receivers, + } + return None + + # Match things with identical ids first, and collect unmatched into left and right lists + by_id = {} + unmatched = [], [] + for messages, u in zip((left, right), unmatched): + for message in messages: + id = message.get('tags', {}).get('id') + if id: + by_id.setdefault(id, []).append(message) + else: + u.append(message) + + result = [] + for id, messages in by_id.items(): + if len(messages) == 1: + logging.debug(f"Message with id {id} has no match") + result.append(messages[0]) + else: + merged = merge_message(*messages) + if merged is None: + raise ValueError(f"Got two non-matching messages with id {id}: {messages[0]}, {messages[1]}") + logging.debug(f"Merged messages with id {id}") + result.append(merged) + + # For time-range messages, pair off each one in left with first match in right, + # and pass through anything with no matches. + left_unmatched, right_unmatched = unmatched + for message in left_unmatched: + for other in right_unmatched: + merged = merge_message(message, other) + if merged: + logging.debug( + "Matched {m[command]} message {a[time]}+{a[time_range]} & {b[time]}+{b[time_range]} -> {m[time]}+{m[time_range]}" + .format(a=message, b=other, m=merged) + ) + right_unmatched.remove(other) + result.append(merged) + break + else: + logging.debug("No match found for {m[command]} at {m[time]}+{m[time_range]}".format(m=message)) + result.append(message) + for message in right_unmatched: + logging.debug("No match found for {m[command]} at {m[time]}+{m[time_range]}".format(m=message)) + result.append(message) + + return result + + def main(channel, nick, oauth_token_path, base_dir='/mnt'): with open(oauth_token_path) as f: oauth_token = f.read() diff --git a/chat_archiver/chat_archiver/merge.py b/chat_archiver/chat_archiver/merge.py new file mode 100644 index 0000000..50a2abc --- /dev/null +++ b/chat_archiver/chat_archiver/merge.py @@ -0,0 +1,20 @@ + +import argh +import logging +import json + +from .main import merge_messages, format_batch + +def main(*paths, log='INFO'): + """Merge all listed batch files and output result to stdout""" + logging.basicConfig(level=log) + messages = [] + for path in paths: + with open(path) as f: + batch = f.read() + batch = [json.loads(line) for line in batch.strip().split("\n")] + messages = merge_messages(messages, batch) + print(format_batch(messages)) + +if __name__ == '__main__': + argh.dispatch_command(main)