diff --git a/buscribe/Dockerfile b/buscribe/Dockerfile
new file mode 100644
index 0000000..23fff0d
--- /dev/null
+++ b/buscribe/Dockerfile
@@ -0,0 +1,17 @@
+FROM debian:latest
+
+RUN apt update &&\
+    apt install -y python3 libpq-dev python3-pip curl unzip ffmpeg
+
+COPY ../common /tmp/common
+RUN pip install /tmp/common && rm -r /tmp/common
+
+COPY buscribe /tmp/buscribe
+RUN pip install /tmp/buscribe && rm -r /tmp/buscribe && \
+    mkdir /usr/share/buscribe && cd /usr/share/buscribe && \
+    curl -LO http://alphacephei.com/vosk/models/vosk-model-small-en-us-0.15.zip && \
+    unzip vosk-model-small-en-us-0.15.zip && rm vosk-model-small-en-us-0.15.zip && \
+    curl -LO https://alphacephei.com/vosk/models/vosk-model-spk-0.4.zip && \
+    unzip vosk-model-spk-0.4.zip && rm vosk-model-spk-0.4.zip
+
+ENTRYPOINT ["python3", "-m", "buscribe", "--base-dir", "/mnt"]
diff --git a/buscribe/buscribe/__init__.py b/buscribe/buscribe/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/buscribe/buscribe/__main__.py b/buscribe/buscribe/__main__.py
new file mode 100644
index 0000000..f62ff0d
--- /dev/null
+++ b/buscribe/buscribe/__main__.py
@@ -0,0 +1,12 @@
+import logging
+import os
+
+import argh
+
+from buscribe.main import main
+
+LOG_FORMAT = "[%(asctime)s] %(levelname)8s %(name)s(%(module)s:%(lineno)d): %(message)s"
+
+level = os.environ.get('WUBLOADER_LOG_LEVEL', 'INFO').upper()
+logging.basicConfig(level=level, format=LOG_FORMAT)
+argh.dispatch_command(main)
diff --git a/buscribe/buscribe/buscribe.py b/buscribe/buscribe/buscribe.py
new file mode 100644
index 0000000..a4ed41c
--- /dev/null
+++ b/buscribe/buscribe/buscribe.py
@@ -0,0 +1,98 @@
+import json
+import logging
+import subprocess
+from datetime import timedelta, datetime
+
+from psycopg2._psycopg import cursor
+
+from buscribe.recognizer import BuscribeRecognizer
+
+
+class HitMissingSegment(Exception):
+    pass
+
+
+def transcribe_segments(segments: list, sample_rate: int, recognizer: BuscribeRecognizer, start_of_transcript: datetime,
+                        db_cursor: cursor):
+    """Starts transcribing from a list of segments.
+
+    Only starts committing new lines to the database after reaching start_of_transcript.
+
+    The recognizer must be initialized to sample_rate and have start time set.
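# A minimal sketch of the entry-point pattern used in __main__.py above: the log level
# comes from an environment variable and argh turns the function signature into a CLI.
# The "demo" function and its flag are illustrative, not part of the real module.
import logging
import os

import argh


def demo(name, count=1):
    """Log a greeting COUNT times."""
    for _ in range(int(count)):
        logging.info("Hello, %s", name)


if __name__ == '__main__':
    level = os.environ.get('WUBLOADER_LOG_LEVEL', 'INFO').upper()
    logging.basicConfig(level=level)
    # argh builds the argument parser from the signature, e.g. `python demo.py world --count 3`
    argh.dispatch_command(demo)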
+ + Returns the end time of the last transcribed line.""" + + segments_end_time = segments[0].start + + for segment in segments: + + if segment is None: + return segments_end_time + + segments_end_time += segment.duration + + process = subprocess.Popen(['ffmpeg', + '-loglevel', 'quiet', + '-i', segment.path, + '-ar', str(sample_rate), + '-ac', '1', # TODO: Check for advanced downmixing + '-f', 's16le', '-'], + stdout=subprocess.PIPE) + while True: + data = process.stdout.read(16000) + if len(data) == 0: + break + if recognizer.AcceptWaveform(data): + result_json = json.loads(recognizer.Result()) + logging.debug(json.dumps(result_json, indent=2)) + + if result_json["text"] == "": + continue + + line_start_time = recognizer.segments_start_time + timedelta(seconds=result_json["result"][0]["start"]) + line_end_time = recognizer.segments_start_time + timedelta(seconds=result_json["result"][-1]["end"]) + + if line_start_time > start_of_transcript: + write_line(result_json, line_start_time, line_end_time, db_cursor) + + return segments_end_time + + +def write_line(line_json: dict, line_start_time: datetime, line_end_time: datetime, db_cursor): + """Commits line to the database""" + db_cursor.execute( + "INSERT INTO buscribe.public.buscribe_transcriptions(" + "start_time, " + "end_time, " + "transcription_line, " + "line_speaker, " + "transcription_json) VALUES (%s, %s ,%s, %s, %s)", + (line_start_time, + line_end_time, + line_json["text"], + line_json["spk"] if "spk" in line_json else None, + json.dumps(line_json) + ) + ) + + +def get_end_of_transcript(db_cursor): + """Grab the end timestamp of the current transcript. + + If there is no existing transcript returns default; used for cold starts.""" + db_cursor.execute("SELECT end_time FROM buscribe.public.buscribe_transcriptions ORDER BY end_time DESC LIMIT 1") + end_of_transcript_row = db_cursor.fetchone() + + return end_of_transcript_row.end_time if end_of_transcript_row is not None else None + + +def finish_off_recognizer(recognizer: BuscribeRecognizer, db_cursor): + """Flush the recognizer, commit the final line to the database and reset it.""" + final_result_json = json.loads(recognizer.FinalResult()) # Flush the tubes + + line_start_time = recognizer.segments_start_time + timedelta(seconds=final_result_json["result"][0]["start"]) + line_end_time = recognizer.segments_start_time + timedelta(seconds=final_result_json["result"][-1]["end"]) + + write_line(final_result_json, line_start_time, line_end_time, db_cursor) + + recognizer.Reset() diff --git a/buscribe/buscribe/main.py b/buscribe/buscribe/main.py new file mode 100644 index 0000000..fd935a2 --- /dev/null +++ b/buscribe/buscribe/main.py @@ -0,0 +1,95 @@ +import logging +import os +from datetime import timedelta, datetime +from time import sleep + +import argh +import common +from common import dateutil +from common.database import DBManager + +from buscribe.buscribe import get_end_of_transcript, transcribe_segments, finish_off_recognizer +from buscribe.recognizer import BuscribeRecognizer + + +@argh.arg('--database', + help='Postgres conection string for database to write transcribed lines to. Either a space-separated list of ' + 'key=value pairs, or a URI like: postgresql://USER:PASSWORD@HOST/DBNAME?KEY=VALUE .') +@argh.arg('--model', + help='Path to STT model files. Defaults to /usr/share/buscribe/vosk-model-en-us-0.21/') +@argh.arg('--spk-model', + help='Path to speaker recognition model files. 
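# A standalone sketch of the decode-and-recognise loop that transcribe_segments() above
# implements: ffmpeg converts a segment to raw 16-bit mono PCM on stdout and the bytes are
# fed to a Vosk recognizer in chunks. The model path and input filename are placeholders;
# any Vosk model directory and any file ffmpeg can read will do.
import json
import subprocess

from vosk import Model, KaldiRecognizer

SAMPLE_RATE = 48000
model = Model("/usr/share/buscribe/vosk-model-small-en-us-0.15")  # assumed model location
recognizer = KaldiRecognizer(model, SAMPLE_RATE)
recognizer.SetWords(True)  # include per-word timings in results

process = subprocess.Popen(
    ['ffmpeg', '-loglevel', 'quiet', '-i', 'example_segment.ts',
     '-ar', str(SAMPLE_RATE), '-ac', '1', '-f', 's16le', '-'],
    stdout=subprocess.PIPE)

while True:
    data = process.stdout.read(16000)
    if not data:
        break
    if recognizer.AcceptWaveform(data):
        result = json.loads(recognizer.Result())
        if result["text"]:
            print(result["text"])

# Flush whatever audio is still buffered inside the recognizer.
print(json.loads(recognizer.FinalResult())["text"])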
Defaults to /usr/share/buscribe/vosk-model-spk-0.4/') +@argh.arg('--start-time', + help='Start time of the transcript. Buscript will try to start reading 2 min before this time, if available, ' + 'to prime the model. The transcripts for that time will not be written to the database. If not given ' + 'transcription will start after last already transcribed line.') +@argh.arg('--end-time', + help='End of transcript. If not given continues to transcribe live.') +@argh.arg('--base-dir', + help='Directory from which segments will be grabbed. Default is current working directory.') +def main(database="", base_dir=".", + model="/usr/share/buscribe/vosk-model-en-us-0.21/", spk_model="/usr/share/buscribe/vosk-model-spk-0.4/", + start_time=None, end_time=None): + SAMPLE_RATE = 48000 + segments_dir = os.path.join(base_dir, "desertbus", "source") + + logging.debug("Grabbing database...") + db_manager = DBManager(dsn=database) + db_conn = db_manager.get_conn() + db_cursor = db_conn.cursor() + logging.debug("Got database cursor.") + + logging.info("Figuring out starting time...") + if start_time is not None: + start_time = dateutil.parse(start_time) + else: + start_time = get_end_of_transcript(db_cursor) + + if end_time is not None: + end_time = dateutil.parse(end_time) + + # No start time argument AND no end of transcript (empty database) + if start_time is None: + logging.error("Couldn't figure out start time!") + db_conn.close() + exit(1) + + logging.info("Loading models...") + recognizer = BuscribeRecognizer(SAMPLE_RATE, model, spk_model) + logging.info("Models loaded.") + + logging.info('Transcribing from {}'.format(start_time)) + + # Start priming the recognizer if possible + start_time -= timedelta(minutes=2) + + while True: + # If end time isn't given, use current time (plus fudge) to get a "live" segment list + segments = common.get_best_segments(segments_dir, + start_time, + end_time if end_time is not None else datetime.now() + timedelta(minutes=2)) + # Remove initial None segment if it exists + if segments[0] is None: + segments = segments[1:] + + if recognizer.segments_start_time is None: + recognizer.segments_start_time = segments[0].start + + segments_end_time = transcribe_segments(segments, SAMPLE_RATE, recognizer, start_time, db_cursor) + + if end_time is not None and segments_end_time >= end_time: + # Work's done! + finish_off_recognizer(recognizer, db_cursor) + db_conn.close() + exit(0) + elif datetime.now() - segments_end_time > timedelta(minutes=5): + # Last seen segment ended more than five minutes ago. We hit a gap that will likely stay unfilled. + # Reset and jump to the other end of the gap. + finish_off_recognizer(recognizer, db_cursor) + else: + # End of live segment or a gap that is not old and might get filled. + # Give it a bit of time and continue. + # Note: if the gap is not filled within 30s, we jump to the next available segment. + sleep(30) + + start_time = segments_end_time diff --git a/buscribe/buscribe/recognizer.py b/buscribe/buscribe/recognizer.py new file mode 100644 index 0000000..888c77e --- /dev/null +++ b/buscribe/buscribe/recognizer.py @@ -0,0 +1,23 @@ +from vosk import Model, SpkModel, KaldiRecognizer + + +class BuscribeRecognizer(KaldiRecognizer): + segments_start_time = None + + def __init__(self, sample_rate=48000, model_path="model_small", spk_model_path="spk_model"): + """Loads the speech recognition model and initializes the recognizer. + + Model paths are file paths to the directories that contain the models. + + Returns a recognizer object. 
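# The word timings Vosk reports are seconds from the start of the audio fed to the
# recognizer, so buscribe anchors them to recognizer.segments_start_time as shown above.
# A small worked example of that conversion; the result dict is made up for illustration.
from datetime import datetime, timedelta

segments_start_time = datetime(2021, 11, 13, 2, 0, 0)
result_json = {
    "text": "hello bus",
    "result": [
        {"word": "hello", "start": 12.34, "end": 12.80},
        {"word": "bus", "start": 12.85, "end": 13.20},
    ],
}

line_start_time = segments_start_time + timedelta(seconds=result_json["result"][0]["start"])
line_end_time = segments_start_time + timedelta(seconds=result_json["result"][-1]["end"])
print(line_start_time, line_end_time)  # 2021-11-13 02:00:12.340000 2021-11-13 02:00:13.200000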
+ """ + self.model = Model(model_path) + self.spk_model = SpkModel(spk_model_path) + + super(BuscribeRecognizer, self).__init__(self.model, sample_rate, self.spk_model) + + self.SetWords(True) + + def Reset(self): + super(BuscribeRecognizer, self).Reset() + self.segments_start_time = None diff --git a/buscribe/setup.py b/buscribe/setup.py new file mode 100644 index 0000000..d0de350 --- /dev/null +++ b/buscribe/setup.py @@ -0,0 +1,16 @@ +from setuptools import setup, find_packages + +setup( + name = "wubloader-buscribe", + version = "0.0.0", + packages = find_packages(), + install_requires = [ + "argh", + "psycopg2", + "greenlet==0.4.16", + "psycogreen", + "wubloader-common", + "python-dateutil", + "vosk" + ], +) diff --git a/buscribe_data.sql b/buscribe_data.sql new file mode 100644 index 0000000..8459552 --- /dev/null +++ b/buscribe_data.sql @@ -0,0 +1,11 @@ +DROP TABLE buscribe_transcriptions; + +CREATE TABLE buscribe_transcriptions +( + id BIGSERIAL PRIMARY KEY, + start_time timestamp without time zone NOT NULL, + end_time timestamp without time zone NOT NULL, + transcription_line text NOT NULL, + line_speaker float[128], + transcription_json jsonb NOT NULL +); \ No newline at end of file diff --git a/common/common/__init__.py b/common/common/__init__.py new file mode 100644 index 0000000..299db19 --- /dev/null +++ b/common/common/__init__.py @@ -0,0 +1,124 @@ + +"""A place for common utilities between wubloader components""" +import datetime +import errno +import os +import random + +from .segments import get_best_segments, rough_cut_segments, fast_cut_segments, full_cut_segments, parse_segment_path, SegmentInfo +from .stats import timed, PromLogCountsHandler, install_stacksampler + + +def dt_to_bustime(start, dt): + """Convert a datetime to bus time. Bus time is seconds since the given start point.""" + return (dt - start).total_seconds() + + +def bustime_to_dt(start, bustime): + """Convert from bus time to a datetime""" + return start + datetime.timedelta(seconds=bustime) + + +def parse_bustime(bustime): + """Convert from bus time human-readable string [-]HH:MM[:SS[.fff]] + to float seconds since bustime 00:00. Inverse of format_bustime(), + see it for detail.""" + if bustime.startswith('-'): + # parse without the -, then negate it + return -parse_bustime(bustime[1:]) + + parts = bustime.strip().split(':') + if len(parts) == 2: + hours, mins = parts + secs = 0 + elif len(parts) == 3: + hours, mins, secs = parts + else: + raise ValueError("Invalid bustime: must be HH:MM[:SS]") + hours = int(hours) + mins = int(mins) + secs = float(secs) + return 3600 * hours + 60 * mins + secs + + +def format_bustime(bustime, round="millisecond"): + """Convert bustime to a human-readable string (-)HH:MM:SS.fff, with the + ending cut off depending on the value of round: + "millisecond": (default) Round to the nearest millisecond. + "second": Round down to the current second. + "minute": Round down to the current minute. + Examples: + 00:00:00.000 + 01:23:00 + 110:50 + 159:59:59.999 + -10:30:01.100 + Negative times are formatted as time-until-start, preceeded by a minus + sign. + eg. "-1:20:00" indicates the run begins in 80 minutes. 
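# Usage sketch for the bus-time helpers defined in common/__init__.py above, assuming the
# wubloader-common package from this diff is installed. The bustime-zero datetime is arbitrary.
from datetime import datetime

from common import dt_to_bustime, bustime_to_dt, parse_bustime, format_bustime

start = datetime(2021, 11, 13, 0, 0, 0)  # bustime 00:00

dt = datetime(2021, 11, 13, 1, 23, 45)
seconds = dt_to_bustime(start, dt)
print(seconds)                  # 5025.0
print(format_bustime(seconds))  # 01:23:45.000
print(bustime_to_dt(start, parse_bustime("-0:30")))  # 2021-11-12 23:30:00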
+ """ + sign = '' + if bustime < 0: + sign = '-' + bustime = -bustime + total_mins, secs = divmod(bustime, 60) + hours, mins = divmod(total_mins, 60) + parts = [ + "{:02d}".format(int(hours)), + "{:02d}".format(int(mins)), + ] + if round == "minute": + pass + elif round == "second": + parts.append("{:02d}".format(int(secs))) + elif round == "millisecond": + parts.append("{:06.3f}".format(secs)) + else: + raise ValueError("Bad rounding value: {!r}".format(round)) + return sign + ":".join(parts) + + +def rename(old, new): + """Atomic rename that succeeds if the target already exists, since we're naming everything + by hash anyway, so if the filepath already exists the file itself is already there. + In this case, we delete the source file. + """ + try: + os.rename(old, new) + except OSError as e: + if e.errno != errno.EEXIST: + raise + os.remove(old) + + +def ensure_directory(path): + """Create directory that contains path, as well as any parent directories, + if they don't already exist.""" + dir_path = os.path.dirname(path) + os.makedirs(dir_path, exist_ok=True) + + +def jitter(interval): + """Apply some 'jitter' to an interval. This is a random +/- 10% change in order to + smooth out patterns and prevent everything from retrying at the same time. + """ + return interval * (0.9 + 0.2 * random.random()) + + +def writeall(write, value): + """Helper for writing a complete string to a file-like object. + Pass the write function and the value to write, and it will loop if needed to ensure + all data is written. + Works for both text and binary files, as long as you pass the right value type for + the write function. + """ + while value: + n = write(value) + if n is None: + # The write func doesn't return the amount written, assume it always writes everything + break + if n == 0: + # This would cause an infinite loop...blow up instead so it's clear what the problem is + raise Exception("Wrote 0 chars while calling {} with {}-char {}".format(write, len(value), type(value).__name__)) + # remove the first n chars and go again if we have anything left + value = value[n:] diff --git a/common/common/database.py b/common/common/database.py new file mode 100644 index 0000000..8bac136 --- /dev/null +++ b/common/common/database.py @@ -0,0 +1,73 @@ + +""" +Code shared between components that touch the database. +Note that this code requires psycopg2 and psycogreen, but the common module +as a whole does not to avoid needing to install them for components that don't need it. +""" + +from contextlib import contextmanager + +import psycopg2 +import psycopg2.extensions +import psycopg2.extras +from psycogreen.gevent import patch_psycopg + + +class DBManager(object): + """Patches psycopg2 before any connections are created. Stores connect info + for easy creation of new connections, and sets some defaults before + returning them. + + It has the ability to serve as a primitive connection pool, as getting a + new conn will return existing conns it knows about first, but you + should use a real conn pool for any non-trivial use. 
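# Usage sketch for the writeall() helper above: it keeps calling the given write function
# until the whole value is accepted, which matters for writers such as sockets that may do
# short writes. The address is a placeholder for something listening locally.
import socket

from common import writeall

sock = socket.create_connection(("127.0.0.1", 9999))
try:
    # sock.send may accept fewer bytes than offered; writeall loops until everything is sent.
    writeall(sock.send, b"hello wubloader\n")
finally:
    sock.close()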
+ + Returned conns are set to seralizable isolation level, autocommit, and use + NamedTupleCursor cursors.""" + def __init__(self, connect_timeout=30, **connect_kwargs): + patch_psycopg() + self.conns = [] + self.connect_timeout = connect_timeout + self.connect_kwargs = connect_kwargs + + def put_conn(self, conn): + self.conns.append(conn) + + def get_conn(self): + if self.conns: + return self.conns.pop(0) + conn = psycopg2.connect(cursor_factory=psycopg2.extras.NamedTupleCursor, + connect_timeout=self.connect_timeout, **self.connect_kwargs) + # We use serializable because it means less issues to think about, + # we don't care about the performance concerns and everything we do is easily retryable. + # This shouldn't matter in practice anyway since everything we're doing is either read-only + # searches or targetted single-row updates. + conn.isolation_level = psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE + conn.autocommit = True + return conn + + +@contextmanager +def transaction(conn): + """Helper context manager that runs the code block as a single database transaction + instead of in autocommit mode. The only difference between this and "with conn" is + that we explicitly disable then re-enable autocommit.""" + old_autocommit = conn.autocommit + conn.autocommit = False + try: + with conn: + yield + finally: + conn.autocommit = old_autocommit + + +def query(conn, query, *args, **kwargs): + """Helper that takes a conn, creates a cursor and executes query against it, + then returns the cursor. + Variables may be given as positional or keyword args (but not both), corresponding + to %s vs %(key)s placeholder forms.""" + if args and kwargs: + raise TypeError("Cannot give both args and kwargs") + cur = conn.cursor() + cur.execute(query, args or kwargs or None) + return cur diff --git a/common/common/dateutil.py b/common/common/dateutil.py new file mode 100644 index 0000000..7793e6e --- /dev/null +++ b/common/common/dateutil.py @@ -0,0 +1,23 @@ + + +"""Wrapper code around dateutil to use it more sanely""" + + +# required so we are able to import dateutil despite this module also being called dateutil +from __future__ import absolute_import + +import dateutil.parser +import dateutil.tz + + +def parse(timestamp): + """Parse given timestamp, convert to UTC, and return naive UTC datetime""" + dt = dateutil.parser.parse(timestamp) + if dt.tzinfo is not None: + dt = dt.astimezone(dateutil.tz.tzutc()).replace(tzinfo=None) + return dt + + +def parse_utc_only(timestamp): + """Parse given timestamp, but assume it's already in UTC and ignore other timezone info""" + return dateutil.parser.parse(timestamp, ignoretz=True) diff --git a/common/common/flask_stats.py b/common/common/flask_stats.py new file mode 100644 index 0000000..e74b985 --- /dev/null +++ b/common/common/flask_stats.py @@ -0,0 +1,98 @@ +""" +Code shared between components to gather stats from flask methods. +Note that this code requires flask, but the common module as a whole does not +to avoid needing to install them for components that don't need it. +""" + +import functools + +from flask import request +from flask import g as request_store +from monotonic import monotonic +import prometheus_client as prom + + +# Generic metrics that all http requests get logged to (see below for specific metrics per endpoint) + +LATENCY_HELP = "Time taken to run the request handler and create a response" +# buckets: very long playlists / cutting can be quite slow, +# so we have a wider range of latencies than default, up to 10min. 
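# Usage sketch for DBManager, transaction() and query() from common.database above, assuming
# the buscribe_transcriptions table from buscribe_data.sql exists. The DSN is a placeholder.
from common.database import DBManager, transaction, query

db_manager = DBManager(dsn='postgresql://user:password@localhost/buscribe')
conn = db_manager.get_conn()

# Connections come back in autocommit mode with NamedTupleCursor, so rows have attributes.
row = query(conn, "SELECT count(*) AS lines FROM buscribe_transcriptions").fetchone()
print(row.lines)

# Group statements that must succeed or fail together into one explicit transaction.
with transaction(conn):
    query(conn, "DELETE FROM buscribe_transcriptions WHERE end_time < %s", '2021-11-01')
    query(conn, "DELETE FROM buscribe_transcriptions WHERE transcription_line = %s", '')

db_manager.put_conn(conn)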
+LATENCY_BUCKETS = [.001, .005, .01, .05, .1, .5, 1, 5, 10, 30, 60, 120, 300, 600] +generic_latency = prom.Histogram( + 'http_request_latency_all', LATENCY_HELP, + ['endpoint', 'method', 'status'], + buckets=LATENCY_BUCKETS, +) + +CONCURRENT_HELP = 'Number of requests currently ongoing' +generic_concurrent = prom.Gauge( + 'http_request_concurrency_all', CONCURRENT_HELP, + ['endpoint', 'method'], +) + + +def request_stats(fn): + """Decorator that wraps a handler func to collect metrics. + Adds handler func args as labels, along with 'endpoint' label using func's name, + method and response status where applicable.""" + # We have to jump through some hoops here, because the prometheus client lib demands + # we pre-define our label names, but we don't know the names of the handler kwargs + # until the first time the function's called. So we delay defining the metrics until + # first call. + # In addition, it doesn't let us have different sets of labels with the same name. + # So we record everything twice: Once under a generic name with only endpoint, method + # and status, and once under a name specific to the endpoint with the full set of labels. + metrics = {} + endpoint = fn.__name__ + + @functools.wraps(fn) + def _stats(**kwargs): + if not metrics: + # first call, set up metrics + labels_no_status = sorted(kwargs.keys()) + ['endpoint', 'method'] + labels = labels_no_status + ['status'] + metrics['latency'] = prom.Histogram( + 'http_request_latency_{}'.format(endpoint), LATENCY_HELP, + labels, buckets=LATENCY_BUCKETS, + ) + metrics['concurrent'] = prom.Gauge( + 'http_request_concurrency_{}'.format(endpoint), CONCURRENT_HELP, + labels_no_status, + ) + + request_store.metrics = metrics + request_store.endpoint = endpoint + request_store.method = request.method + request_store.labels = {k: str(v) for k, v in kwargs.items()} + generic_concurrent.labels(endpoint=endpoint, method=request.method).inc() + metrics['concurrent'].labels(endpoint=endpoint, method=request.method, **request_store.labels).inc() + request_store.start_time = monotonic() + return fn(**kwargs) + + return _stats + + +def after_request(response): + """Must be registered to run after requests. Finishes tracking the request + and logs most of the metrics. + We do it in this way, instead of inside the request_stats wrapper, because it lets flask + normalize the handler result into a Response object. 
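# Minimal sketch of wiring the flask_stats helpers above into an app: decorate each handler
# with request_stats and register after_request so the latency and concurrency metrics get
# finished off. The app, route and port are illustrative.
from flask import Flask
import prometheus_client as prom

from common.flask_stats import request_stats, after_request

app = Flask(__name__)
app.after_request(after_request)


@app.route('/transcript/<channel>')
@request_stats
def get_transcript(channel):
    # 'channel' becomes a label on the per-endpoint http_request_* metrics.
    return "transcript for {}\n".format(channel)


@app.route('/metrics')
def metrics():
    return prom.generate_latest()


if __name__ == '__main__':
    app.run(port=8080)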
+ """ + if 'metrics' not in request_store: + return response # untracked handler + + end_time = monotonic() + metrics = request_store.metrics + endpoint = request_store.endpoint + method = request_store.method + labels = request_store.labels + start_time = request_store.start_time + + generic_concurrent.labels(endpoint=endpoint, method=method).dec() + metrics['concurrent'].labels(endpoint=endpoint, method=method, **labels).dec() + + status = str(response.status_code) + generic_latency.labels(endpoint=endpoint, method=method, status=status).observe(end_time - start_time) + metrics['latency'].labels(endpoint=endpoint, method=method, status=status, **labels).observe(end_time - start_time) + + return response diff --git a/common/common/googleapis.py b/common/common/googleapis.py new file mode 100644 index 0000000..d7aa292 --- /dev/null +++ b/common/common/googleapis.py @@ -0,0 +1,67 @@ + +import time +import logging + +import gevent + +from .requests import InstrumentedSession + +# Wraps all requests in some metric collection +requests = InstrumentedSession() + + +class GoogleAPIClient(object): + """Manages access to google apis and maintains an active access token. + Make calls using client.request(), which is a wrapper for requests.request(). + """ + + ACCESS_TOKEN_ERROR_RETRY_INTERVAL = 10 + # Refresh token 10min before it expires (it normally lasts an hour) + ACCESS_TOKEN_REFRESH_TIME_BEFORE_EXPIRY = 600 + + def __init__(self, client_id, client_secret, refresh_token): + self.client_id = client_id + self.client_secret = client_secret + self.refresh_token = refresh_token + + self._first_get_access_token = gevent.spawn(self.get_access_token) + + @property + def access_token(self): + """Blocks if access token unavailable yet""" + self._first_get_access_token.join() + return self._access_token + + def get_access_token(self): + """Authenticates against google's API and retrieves a token we will use in + subsequent requests. + This function gets called automatically when needed, there should be no need to call it + yourself.""" + while True: + try: + start_time = time.time() + resp = requests.post('https://www.googleapis.com/oauth2/v4/token', data={ + 'client_id': self.client_id, + 'client_secret': self.client_secret, + 'refresh_token': self.refresh_token, + 'grant_type': 'refresh_token', + }, metric_name='get_access_token') + resp.raise_for_status() + data = resp.json() + self._access_token = data['access_token'] + expires_in = (start_time + data['expires_in']) - time.time() + if expires_in < self.ACCESS_TOKEN_REFRESH_TIME_BEFORE_EXPIRY: + self.logger.warning("Access token expires in {}s, less than normal leeway time of {}s".format( + expires_in, self.ACCESS_TOKEN_REFRESH_TIME_BEFORE_EXPIRY, + )) + gevent.spawn_later(expires_in - self.ACCESS_TOKEN_REFRESH_TIME_BEFORE_EXPIRY, self.get_access_token) + except Exception: + logging.exception("Failed to fetch access token, retrying") + gevent.sleep(self.ACCESS_TOKEN_ERROR_RETRY_INTERVAL) + else: + break + + def request(self, method, url, headers={}, **kwargs): + # merge in auth header + headers = dict(headers, Authorization='Bearer {}'.format(self.access_token)) + return requests.request(method, url, headers=headers, **kwargs) diff --git a/common/common/requests.py b/common/common/requests.py new file mode 100644 index 0000000..194dafb --- /dev/null +++ b/common/common/requests.py @@ -0,0 +1,55 @@ + +"""Code for instrumenting requests calls. 
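# Usage sketch for GoogleAPIClient above. The credentials are placeholders and the endpoint
# is just an example; the client injects the Authorization header and refreshes the access
# token via gevent, so it is meant to run inside a gevent-based service.
from common.googleapis import GoogleAPIClient

client = GoogleAPIClient(
    client_id='YOUR_CLIENT_ID',
    client_secret='YOUR_CLIENT_SECRET',
    refresh_token='YOUR_REFRESH_TOKEN',
)

resp = client.request('GET',
    'https://www.googleapis.com/youtube/v3/channels',
    params={'part': 'snippet', 'mine': 'true'},
    metric_name='example_channels_list',
)
resp.raise_for_status()
print(resp.json())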
Requires requests, obviously.""" + +import urllib.parse + +import requests.sessions +import prometheus_client as prom +from monotonic import monotonic + +request_latency = prom.Histogram( + 'http_client_request_latency', + 'Time taken to make an outgoing HTTP request. ' + 'Status = "error" is used if an error occurs. Measured as time from first byte sent to ' + 'headers finished being parsed, ie. does not include reading a streaming response.', + ['name', 'method', 'domain', 'status'], +) + +response_size = prom.Histogram( + 'http_client_response_size', + "The content length of (non-streaming) responses to outgoing HTTP requests.", + ['name', 'method', 'domain', 'status'], +) + +request_concurrency = prom.Gauge( + 'http_client_request_concurrency', + "The number of outgoing HTTP requests currently ongoing", + ['name', 'method', 'domain'], +) + +class InstrumentedSession(requests.sessions.Session): + """A requests Session that automatically records metrics on requests made. + Users may optionally pass a 'metric_name' kwarg that will be included as the 'name' label. + """ + + def request(self, method, url, *args, **kwargs): + _, domain, _, _, _ = urllib.parse.urlsplit(url) + name = kwargs.pop('metric_name', '') + + start = monotonic() # we only use our own measured latency if an error occurs + try: + with request_concurrency.labels(name, method, domain).track_inprogress(): + response = super().request(method, url, *args, **kwargs) + except Exception: + latency = monotonic() - start + request_latency.labels(name, method, domain, "error").observe(latency) + raise + + request_latency.labels(name, method, domain, response.status_code).observe(response.elapsed.total_seconds()) + try: + content_length = int(response.headers['content-length']) + except (KeyError, ValueError): + pass # either not present or not valid + else: + response_size.labels(name, method, domain, response.status_code).observe(content_length) + return response diff --git a/common/common/segments.py b/common/common/segments.py new file mode 100644 index 0000000..7a6e1da --- /dev/null +++ b/common/common/segments.py @@ -0,0 +1,513 @@ + +"""A place for common utilities between wubloader components""" + + +import base64 +import datetime +import errno +import itertools +import json +import logging +import os +import shutil +from collections import namedtuple +from contextlib import closing +from tempfile import TemporaryFile + +import gevent +from gevent import subprocess + +from .stats import timed + + +def unpadded_b64_decode(s): + """Decode base64-encoded string that has had its padding removed. + Note it takes a unicode and returns a bytes.""" + # right-pad with '=' to multiple of 4 + s = s + '=' * (- len(s) % 4) + return base64.b64decode(s.encode(), b"-_") + + +class SegmentInfo( + namedtuple('SegmentInfoBase', [ + 'path', 'channel', 'quality', 'start', 'duration', 'type', 'hash' + ]) +): + """Info parsed from a segment path, including original path. + Note that start time is a datetime and duration is a timedelta, and hash is a decoded binary string.""" + @property + def end(self): + return self.start + self.duration + @property + def is_partial(self): + """Note that suspect is considered partial""" + return self.type != "full" + + +def parse_segment_timestamp(hour_str, min_str): + """This is faster than strptime, which dominates our segment processing time. 
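# Usage sketch for InstrumentedSession above: it behaves like a normal requests Session, with
# an optional metric_name kwarg that becomes the 'name' label on the client metrics. The URL
# is a placeholder.
from common.requests import InstrumentedSession

session = InstrumentedSession()
resp = session.get('https://example.com/segments/index.json', metric_name='fetch_index')
resp.raise_for_status()
print(len(resp.content))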
+ It takes strictly formatted hour = "%Y-%m-%dT%H" and time = "%M:%S.%f".""" + year = int(hour_str[0:4]) + month = int(hour_str[5:7]) + day = int(hour_str[8:10]) + hour = int(hour_str[11:13]) + min = int(min_str[0:2]) + sec = int(min_str[3:5]) + microsec_str = min_str[6:] + microsec_str += '0' * (6 - len(microsec_str)) # right-pad zeros to 6 digits, eg. "123" -> "123000" + microsec = int(microsec_str) + return datetime.datetime(year, month, day, hour, min, sec, microsec) + + +def parse_segment_path(path): + """Parse segment path, returning a SegmentInfo. If path is only the trailing part, + eg. just a filename, it will leave unknown fields as None.""" + parts = path.split('/') + # left-pad parts with None up to 4 parts + parts = [None] * (4 - len(parts)) + parts + # pull info out of path parts + channel, quality, hour, filename = parts[-4:] + # split filename, which should be TIME-DURATION-TYPE-HASH.ts + try: + if not filename.endswith('.ts'): + raise ValueError("Does not end in .ts") + filename = filename[:-len('.ts')] # chop off .ts + parts = filename.split('-', 3) + if len(parts) != 4: + raise ValueError("Not enough dashes in filename") + time, duration, type, hash = parts + if type not in ('full', 'suspect', 'partial', 'temp'): + raise ValueError("Unknown type {!r}".format(type)) + hash = None if type == 'temp' else unpadded_b64_decode(hash) + start = None if hour is None else parse_segment_timestamp(hour, time) + return SegmentInfo( + path = path, + channel = channel, + quality = quality, + start = start, + duration = datetime.timedelta(seconds=float(duration)), + type = type, + hash = hash, + ) + except ValueError as e: + # wrap error but preserve original traceback + raise ValueError("Bad path {!r}: {}".format(path, e)).with_traceback(e.__traceback__) + + +class ContainsHoles(Exception): + """Raised by get_best_segments() when a hole is found and allow_holes is False""" + + +@timed( + hours_path=lambda ret, hours_path, *args, **kwargs: hours_path, + has_holes=lambda ret, *args, **kwargs: None in ret, + normalize=lambda ret, *args, **kwargs: len([x for x in ret if x is not None]), +) +def get_best_segments(hours_path, start, end, allow_holes=True): + """Return a list of the best sequence of non-overlapping segments + we have for a given time range. Hours path should be the directory containing hour directories. + Time args start and end should be given as datetime objects. + The first segment may start before the time range, and the last may end after it. + The returned list contains items that are either: + SegmentInfo: a segment + None: represents a discontinuity between the previous segment and the next one. + ie. as long as two segments appear next to each other, we guarentee there is no gap between + them, the second one starts right as the first one finishes. + Similarly, unless the first item is None, the first segment starts <= the start of the time + range, and unless the last item is None, the last segment ends >= the end of the time range. + Example: + Suppose you ask for a time range from 10 to 60. We have 10-second segments covering + the following times: + 5 to 15 + 15 to 25 + 30 to 40 + 40 to 50 + Then the output would look like: + segment from 5 to 15 + segment from 15 to 25 + None, as the previous segment ends 5sec before the next one begins + segment from 30 to 40 + segment from 40 to 50 + None, as the previous segment ends 10sec before the requested end time of 60. 
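# A quick demonstration of parse_segment_path() above on a path following the documented
# <channel>/<quality>/<hour>/<time>-<duration>-<type>-<hash>.ts layout. The path itself is
# made up for illustration.
from common.segments import parse_segment_path

path = "desertbus/source/2021-11-13T02/05:13.000000-2.000-full-oEkb2yfYAz_FB9TqyH4BRQ.ts"
info = parse_segment_path(path)
print(info.channel)     # desertbus
print(info.quality)     # source
print(info.start)       # 2021-11-13 02:05:13
print(info.duration)    # 0:00:02
print(info.end)         # 2021-11-13 02:05:15
print(info.is_partial)  # False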
+ Note that any is_partial=True segment will be followed by a None, since we can't guarentee + it joins on to the next segment fully intact. + + If allow_holes is False, then we fail fast at the first discontinuity found + and raise ContainsHoles. If ContainsHoles is not raised, the output is guarenteed to not contain + any None items. + """ + # Note: The exact equality checks in this function are not vulnerable to floating point error, + # but only because all input dates and durations are only precise to the millisecond, and + # python's datetime types represent these as integer microseconds internally. So the parsing + # to these types is exact, and all operations on them are exact, so all operations are exact. + + result = [] + + for hour in hour_paths_for_range(hours_path, start, end): + # Especially when processing multiple hours, this routine can take a signifigant amount + # of time with no blocking. To ensure other stuff is still completed in a timely fashion, + # we yield to let other things run. + gevent.idle() + + # best_segments_by_start will give us the best available segment for each unique start time + for segment in best_segments_by_start(hour): + + # special case: first segment + if not result: + # first segment is allowed to be before start as long as it includes it + if segment.start <= start < segment.end: + # segment covers start + result.append(segment) + elif start < segment.start < end: + # segment is after start (but before end), so there was no segment that covers start + # so we begin with a None + if not allow_holes: + raise ContainsHoles + result.append(None) + result.append(segment) + else: + # segment is before start, and doesn't cover start, or starts after end. + # ignore and go to next. + continue + else: + # normal case: check against previous segment end time + prev_end = result[-1].end + if segment.start < prev_end: + # Overlap! This shouldn't happen, though it might be possible due to weirdness + # if the stream drops then starts again quickly. We simply ignore the overlapping + # segment and let the algorithm continue. + logging.warning("Overlapping segments: {} overlaps end of {}".format(segment, result[-1])) + continue + if result[-1].is_partial or prev_end < segment.start: + # there's a gap between prev end and this start, so add a None + if not allow_holes: + raise ContainsHoles + result.append(None) + result.append(segment) + + # check if we've reached the end + if end <= segment.end: + break + + # this is a weird little construct that says "if we broke from the inner loop, + # then also break from the outer one. otherwise continue." + else: + continue + break + + # check if we need a trailing None because last segment is partial or doesn't reach end, + # or we found nothing at all + if not result or result[-1].is_partial or result[-1].end < end: + if not allow_holes: + raise ContainsHoles + result.append(None) + + return result + + +def hour_paths_for_range(hours_path, start, end): + """Generate a list of hour paths to check when looking for segments between start and end.""" + # truncate start and end to the hour + def truncate(dt): + return dt.replace(microsecond=0, second=0, minute=0) + current = truncate(start) + end = truncate(end) + # Begin in the hour prior to start, as there may be a segment that starts in that hour + # but contains the start time, eg. if the start time is 01:00:01 and there's a segment + # at 00:59:59 which goes for 3 seconds. 
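# Sketch of consuming get_best_segments() as described above: the returned list mixes
# SegmentInfo objects with None markers for holes, so callers walk it and decide what a gap
# means for them. The hours path and times are placeholders.
from datetime import datetime

from common import get_best_segments

segments = get_best_segments(
    "/mnt/desertbus/source",  # directory containing the hour directories
    datetime(2021, 11, 13, 2, 0), datetime(2021, 11, 13, 3, 0),
)

covered = 0.0
for segment in segments:
    if segment is None:
        print("hole in coverage here")
        continue
    covered += segment.duration.total_seconds()
print("have {:.0f}s of video for the requested range".format(covered))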
+ # Checking the entire hour when in most cases it won't be needed is wasteful, but it's also + # pretty quick and the complexity of only checking this case when needed just isn't worth it. + current -= datetime.timedelta(hours=1) + while current <= end: + yield os.path.join(hours_path, current.strftime("%Y-%m-%dT%H")) + current += datetime.timedelta(hours=1) + + +def best_segments_by_start(hour): + """Within a given hour path, yield the "best" segment per unique segment start time. + Best is defined as type=full, or failing that type=suspect, or failing that the longest type=partial. + Note this means this function may perform os.stat()s. + """ + try: + segment_paths = os.listdir(hour) + except OSError as e: + if e.errno != errno.ENOENT: + raise + # path does not exist, treat it as having no files + return + segment_paths.sort() + # raise a warning for any files that don't parse as segments and ignore them + parsed = [] + for name in segment_paths: + try: + parsed.append(parse_segment_path(os.path.join(hour, name))) + except ValueError: + logging.warning("Failed to parse segment {!r}".format(os.path.join(hour, name)), exc_info=True) + + for start_time, segments in itertools.groupby(parsed, key=lambda segment: segment.start): + # ignore temp segments as they might go away by the time we want to use them + segments = [segment for segment in segments if segment.type != "temp"] + if not segments: + # all segments were temp, move on + continue + + full_segments = [segment for segment in segments if not segment.is_partial] + if full_segments: + if len(full_segments) != 1: + logging.info("Multiple versions of full segment at start_time {}: {}".format( + start_time, ", ".join(map(str, segments)) + )) + # We've observed some cases where the same segment (with the same hash) will be reported + # with different durations (generally at stream end). Prefer the longer duration (followed by longest size), + # as this will ensure that if hashes are different we get the most data, and if they + # are the same it should keep holes to a minimum. + # If same duration and size, we have to pick one, so pick highest-sorting hash just so we're consistent. + sizes = {segment: os.stat(segment.path).st_size for segment in segments} + full_segments = [max(full_segments, key=lambda segment: (segment.duration, sizes[segment], segment.hash))] + yield full_segments[0] + continue + # no full segments, fall back to measuring partials. Prefer suspect over partial. + yield max(segments, key=lambda segment: ( + 1 if segment.type == 'suspect' else 0, + os.stat(segment.path).st_size, + )) + + +def streams_info(segment): + """Return ffprobe's info on streams as a list of dicts""" + output = subprocess.check_output([ + 'ffprobe', + '-hide_banner', '-loglevel', 'fatal', # suppress noisy output + '-of', 'json', '-show_streams', # get streams info as json + segment.path, + ]) + # output here is a bytes, but json.loads will accept it + return json.loads(output)['streams'] + + +def ffmpeg_cut_segment(segment, cut_start=None, cut_end=None): + """Return a Popen object which is ffmpeg cutting the given single segment. + This is used when doing a fast cut. + """ + args = [ + 'ffmpeg', + '-hide_banner', '-loglevel', 'error', # suppress noisy output + '-i', segment.path, + ] + # output from ffprobe is generally already sorted but let's be paranoid, + # because the order of map args matters. 
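# Usage sketch for streams_info() above: it shells out to ffprobe and returns the 'streams'
# list from its JSON output, which ffmpeg_cut_segment() uses to build its -map and -codec
# arguments. The segment path is a placeholder and must point at a real file for ffprobe to
# inspect.
from common.segments import parse_segment_path, streams_info

segment = parse_segment_path(
    "desertbus/source/2021-11-13T02/05:13.000000-2.000-full-oEkb2yfYAz_FB9TqyH4BRQ.ts")
for stream in streams_info(segment):
    print(stream['index'], stream['codec_type'], stream['codec_name'])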
+ for stream in sorted(streams_info(segment), key=lambda stream: stream['index']): + # map the same stream in the same position from input to output + args += ['-map', '0:{}'.format(stream['index'])] + if stream['codec_type'] in ('video', 'audio'): + # for non-metadata streams, make sure we use the same codec (metadata streams + # are a bit weirder, and ffmpeg will do the right thing anyway) + args += ['-codec:{}'.format(stream['index']), stream['codec_name']] + # now add trim args + if cut_start: + args += ['-ss', str(cut_start)] + if cut_end: + args += ['-to', str(cut_end)] + # output to stdout as MPEG-TS + args += ['-f', 'mpegts', '-'] + # run it + logging.info("Running segment cut with args: {}".format(" ".join(args))) + return subprocess.Popen(args, stdout=subprocess.PIPE) + + +def ffmpeg_cut_stdin(output_file, cut_start, duration, encode_args): + """Return a Popen object which is ffmpeg cutting from stdin. + This is used when doing a full cut. + If output_file is not subprocess.PIPE, + uses explicit output file object instead of using a pipe, + because some video formats require a seekable file. + """ + args = [ + 'ffmpeg', + '-hide_banner', '-loglevel', 'error', # suppress noisy output + '-i', '-', + '-ss', cut_start, + '-t', duration, + ] + list(encode_args) + if output_file is subprocess.PIPE: + args.append('-') # output to stdout + else: + args += [ + # We want ffmpeg to write to our tempfile, which is its stdout. + # However, it assumes that '-' means the output is not seekable. + # We trick it into understanding that its stdout is seekable by + # telling it to write to the fd via its /proc/self filename. + '/proc/self/fd/1', + # But of course, that file "already exists", so we need to give it + # permission to "overwrite" it. + '-y', + ] + args = map(str, args) + logging.info("Running full cut with args: {}".format(" ".join(args))) + return subprocess.Popen(args, stdin=subprocess.PIPE, stdout=output_file) + + +def read_chunks(fileobj, chunk_size=16*1024): + """Read fileobj until EOF, yielding chunk_size sized chunks of data.""" + while True: + chunk = fileobj.read(chunk_size) + if not chunk: + break + yield chunk + + +@timed('cut', cut_type='rough', normalize=lambda _, segments, start, end: (end - start).total_seconds()) +def rough_cut_segments(segments, start, end): + """Yields chunks of a MPEGTS video file covering at least the timestamp range, + likely with a few extra seconds on either side. + This method works by simply concatenating all the segments, without any re-encoding. + """ + for segment in segments: + with open(segment.path, 'rb') as f: + for chunk in read_chunks(f): + yield chunk + + +@timed('cut', cut_type='fast', normalize=lambda _, segments, start, end: (end - start).total_seconds()) +def fast_cut_segments(segments, start, end): + """Yields chunks of a MPEGTS video file covering the exact timestamp range. + segments should be a list of segments as returned by get_best_segments(). + This method works by only cutting the first and last segments, and concatenating the rest. + This only works if the same codec settings etc are used across all segments. + This should almost always be true but may cause weird results if not. 
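# Sketch of producing a fast cut with the helpers above: pick the best segments for a time
# range, then stream the cut to a file. allow_holes defaults to True, so discontinuities are
# simply skipped over. Paths and times are placeholders.
from datetime import datetime

from common import get_best_segments, fast_cut_segments

start = datetime(2021, 11, 13, 2, 15, 0)
end = datetime(2021, 11, 13, 2, 20, 0)
segments = get_best_segments("/mnt/desertbus/source", start, end)

with open("/tmp/cut.ts", "wb") as f:
    for chunk in fast_cut_segments(segments, start, end):
        f.write(chunk)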
+ """ + + # how far into the first segment to begin (if no hole at start) + cut_start = None + if segments[0] is not None: + cut_start = (start - segments[0].start).total_seconds() + if cut_start < 0: + raise ValueError("First segment doesn't begin until after cut start, but no leading hole indicated") + + # how far into the final segment to end (if no hole at end) + cut_end = None + if segments[-1] is not None: + cut_end = (end - segments[-1].start).total_seconds() + if cut_end < 0: + raise ValueError("Last segment ends before cut end, but no trailing hole indicated") + + # Set first and last only if they actually need cutting. + # Note this handles both the cut_start = None (no first segment to cut) + # and cut_start = 0 (first segment already starts on time) cases. + first = segments[0] if cut_start else None + last = segments[-1] if cut_end else None + + for segment in segments: + if segment is None: + logging.debug("Skipping discontinuity while cutting") + # TODO: If we want to be safe against the possibility of codecs changing, + # we should check the streams_info() after each discontinuity. + continue + + # note first and last might be the same segment. + # note a segment will only match if cutting actually needs to be done + # (ie. cut_start or cut_end is not 0) + if segment in (first, last): + proc = None + try: + proc = ffmpeg_cut_segment( + segment, + cut_start if segment == first else None, + cut_end if segment == last else None, + ) + with closing(proc.stdout): + for chunk in read_chunks(proc.stdout): + yield chunk + proc.wait() + except Exception as ex: + # try to clean up proc, ignoring errors + if proc is not None: + try: + proc.kill() + except OSError: + pass + raise ex + else: + # check if ffmpeg had errors + if proc.returncode != 0: + raise Exception( + "Error while streaming cut: ffmpeg exited {}".format(proc.returncode) + ) + else: + # no cutting needed, just serve the file + with open(segment.path, 'rb') as f: + for chunk in read_chunks(f): + yield chunk + + +def feed_input(segments, pipe): + """Write each segment's data into the given pipe in order. + This is used to provide input to ffmpeg in a full cut.""" + for segment in segments: + with open(segment.path, 'rb') as f: + try: + shutil.copyfileobj(f, pipe) + except OSError as e: + # ignore EPIPE, as this just means the end cut meant we didn't need all it + if e.errno != errno.EPIPE: + raise + pipe.close() + + +@timed('cut', + cut_type=lambda _, segments, start, end, encode_args, stream=False: ("full-streamed" if stream else "full-buffered"), + normalize=lambda _, segments, start, end, *a, **k: (end - start).total_seconds(), +) +def full_cut_segments(segments, start, end, encode_args, stream=False): + """If stream=true, assume encode_args gives a streamable format, + and begin returning output immediately instead of waiting for ffmpeg to finish + and buffering to disk.""" + + # Remove holes + segments = [segment for segment in segments if segment is not None] + + # how far into the first segment to begin + cut_start = max(0, (start - segments[0].start).total_seconds()) + # duration + duration = (end - start).total_seconds() + + ffmpeg = None + input_feeder = None + try: + + if stream: + # When streaming, we can just use a pipe + tempfile = subprocess.PIPE + else: + # Some ffmpeg output formats require a seekable file. + # For the same reason, it's not safe to begin uploading until ffmpeg + # has finished. We create a temporary file for this. 
+ tempfile = TemporaryFile() + + ffmpeg = ffmpeg_cut_stdin(tempfile, cut_start, duration, encode_args) + input_feeder = gevent.spawn(feed_input, segments, ffmpeg.stdin) + + # When streaming, we can return data as it is available + if stream: + for chunk in read_chunks(ffmpeg.stdout): + yield chunk + + # check if any errors occurred in input writing, or if ffmpeg exited non-success. + if ffmpeg.wait() != 0: + raise Exception("Error while streaming cut: ffmpeg exited {}".format(ffmpeg.returncode)) + input_feeder.get() # re-raise any errors from feed_input() + + # When not streaming, we can only return the data once ffmpeg has exited + if not stream: + for chunk in read_chunks(tempfile): + yield chunk + finally: + # if something goes wrong, try to clean up ignoring errors + if input_feeder is not None: + input_feeder.kill() + if ffmpeg is not None and ffmpeg.poll() is None: + for action in (ffmpeg.kill, ffmpeg.stdin.close, ffmpeg.stdout.close): + try: + action() + except (OSError, IOError): + pass diff --git a/common/common/stats.py b/common/common/stats.py new file mode 100644 index 0000000..68777d3 --- /dev/null +++ b/common/common/stats.py @@ -0,0 +1,257 @@ + +import atexit +import functools +import logging +import os +import signal + +import gevent.lock +from monotonic import monotonic +import prometheus_client as prom + + +# need to keep global track of what metrics we've registered +# because we're not allowed to re-register +metrics = {} + + +def timed(name=None, + buckets=[10.**x for x in range(-9, 5)], normalized_buckets=None, + normalize=None, + **labels +): + """Decorator that instruments wrapped function to record real, user and system time + as a prometheus histogram. + + Metrics are recorded as NAME_latency, NAME_cputime{type=user} and NAME_cputime{type=system} + respectively. User and system time are process-wide (which means they'll be largely meaningless + if you're using gevent and the wrapped function blocks) and do not include subprocesses. + + NAME defaults to the wrapped function's name. + NAME must be unique OR have the exact same labels as other timed() calls with that name. + + Any labels passed in are included. Given label values may be callable, in which case + they are passed the input and result from the wrapped function and should return a label value. + Otherwise the given label value is used directly. All label values are automatically str()'d. + + In addition, the "error" label is automatically included, and set to "" if no exception + occurs, or the name of the exception type if one does. + + The normalize argument, if given, causes the creation of a second set of metrics + NAME_normalized_latency, etc. The normalize argument should be a callable which + takes the input and result of the wrapped function and returns a normalization factor. + All normalized metrics divide the observed times by this factor. + The intent is to allow a function which is expected to take longer given a larger input + to be timed on a per-input basis. + As a special case, when normalize returns 0 or None, normalized metrics are not updated. + + The buckets kwarg is as per prometheus_client.Histogram. The default is a conservative + but sparse range covering nanoseconds to hours. + The normalized_buckets kwarg applies to the normalized metrics, and defaults to the same + as buckets. + + All callables that take inputs and result take them as follows: The first arg is the result, + followed by *args and **kwargs as per the function's inputs. 
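# Sketch of a full (re-encoding) cut using full_cut_segments() above. encode_args is passed
# straight through to ffmpeg; the libx264/mp4 arguments below are just one plausible choice,
# and stream=False is used because mp4 output needs a seekable file. Paths and times are
# placeholders.
from datetime import datetime

from common import get_best_segments, full_cut_segments

start = datetime(2021, 11, 13, 2, 15, 0)
end = datetime(2021, 11, 13, 2, 20, 0)
segments = get_best_segments("/mnt/desertbus/source", start, end)

encode_args = ['-c:v', 'libx264', '-preset', 'veryfast', '-c:a', 'aac', '-f', 'mp4']
with open("/tmp/cut.mp4", "wb") as f:
    for chunk in full_cut_segments(segments, start, end, encode_args, stream=False):
        f.write(chunk)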
+ If the wrapped function errored, result is None. + To simplify error handling in these functions, any errors are taken to mean None, + and None is interpreted as '' for label values. + + Contrived Example: + @timed("scanner", + # constant label + foo="my example label", + # label dependent on input + all=lambda results, predicate, list, find_all=False: find_all, + # label dependent on output + found=lambda results, *a, **k: len(found) > 0, + # normalized on input + normalize=lambda results, predicate, list, **k: len(list), + ) + def scanner(predicate, list, find_all=False): + results = [] + for item in list: + if predicate(item): + results.append(item) + if not find_all: + break + return results + """ + + if normalized_buckets is None: + normalized_buckets = buckets + # convert constant (non-callable) values into callables for consistency + labels = { + # need to create then call a function to properly bind v as otherwise it will + # always return the final label value. + k: v if callable(v) else (lambda v: (lambda *a, **k: v))(v) + for k, v in labels.items() + } + + def _timed(fn): + # can't safely assign to name inside closure, we use a new _name variable instead + _name = fn.__name__ if name is None else name + + if _name in metrics: + latency, cputime = metrics[_name] + else: + latency = prom.Histogram( + "{}_latency".format(_name), + "Wall clock time taken to execute {}".format(_name), + list(labels.keys()) + ['error'], + buckets=buckets, + ) + cputime = prom.Histogram( + "{}_cputime".format(_name), + "Process-wide consumed CPU time during execution of {}".format(_name), + list(labels.keys()) + ['error', 'type'], + buckets=buckets, + ) + metrics[_name] = latency, cputime + if normalize: + normname = '{} normalized'.format(_name) + if normname in metrics: + normal_latency, normal_cputime = metrics[normname] + else: + normal_latency = prom.Histogram( + "{}_latency_normalized".format(_name), + "Wall clock time taken to execute {} per unit of work".format(_name), + list(labels.keys()) + ['error'], + buckets=normalized_buckets, + ) + normal_cputime = prom.Histogram( + "{}_cputime_normalized".format(_name), + "Process-wide consumed CPU time during execution of {} per unit of work".format(_name), + list(labels.keys()) + ['error', 'type'], + buckets=normalized_buckets, + ) + metrics[normname] = normal_latency, normal_cputime + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + start_monotonic = monotonic() + start_user, start_sys, _, _, _ = os.times() + + try: + ret = fn(*args, **kwargs) + except Exception as e: + ret = None + error = e + else: + error = None + + end_monotonic = monotonic() + end_user, end_sys, _, _, _ = os.times() + wall_time = end_monotonic - start_monotonic + user_time = end_user - start_user + sys_time = end_sys - start_sys + + label_values = {} + for k, v in labels.items(): + try: + value = v(ret, *args, **kwargs) + except Exception: + value = None + label_values[k] = '' if value is None else str(value) + label_values.update(error='' if error is None else type(error).__name__) + + latency.labels(**label_values).observe(wall_time) + cputime.labels(type='user', **label_values).observe(user_time) + cputime.labels(type='system', **label_values).observe(sys_time) + if normalize: + try: + factor = normalize(ret, *args, **kwargs) + except Exception: + factor = None + if factor is not None and factor > 0: + normal_latency.labels(**label_values).observe(wall_time / factor) + normal_cputime.labels(type='user', **label_values).observe(user_time / factor) + 
normal_cputime.labels(type='system', **label_values).observe(sys_time / factor) + + if error is None: + return ret + raise error from None # re-raise error with original traceback + + return wrapper + + return _timed + + +log_count = prom.Counter("log_count", "Count of messages logged", ["level", "module", "function"]) + +class PromLogCountsHandler(logging.Handler): + """A logging handler that records a count of logs by level, module and function.""" + def emit(self, record): + log_count.labels(record.levelname, record.module, record.funcName).inc() + + @classmethod + def install(cls): + root_logger = logging.getLogger() + root_logger.addHandler(cls()) + + +def install_stacksampler(interval=0.005): + """Samples the stack every INTERVAL seconds of user time. + We could use user+sys time but that leads to interrupting syscalls, + which may affect performance, and we care mostly about user time anyway. + """ + if os.environ.get('WUBLOADER_ENABLE_STACKSAMPLER', '').lower() != 'true': + return + + logging.info("Installing stacksampler") + + # Note we only start each next timer once the previous timer signal has been processed. + # There are two reasons for this: + # 1. Avoid handling a signal while already handling a signal, however unlikely, + # as this could lead to a deadlock due to locking inside prometheus_client. + # 2. Avoid biasing the results by effectively not including the time taken to do the actual + # stack sampling. + + flamegraph = prom.Counter( + "flamegraph", + "Approx time consumed by each unique stack trace seen by sampling the stack", + ["stack"] + ) + # HACK: It's possible to deadlock if we handle a signal during a prometheus collect + # operation that locks our flamegraph metric. We then try to take the lock when recording the + # metric, but can't. + # As a hacky work around, we replace the lock with a dummy lock that doesn't actually lock anything. + # This is reasonably safe. We know that only one copy of sample() will ever run at once, + # and nothing else but sample() and collect() will touch the metric, leaving two possibilities: + # 1. Multiple collects happen at once: Safe. They only do read operations. + # 2. A sample during a collect: Safe. The collect only does a copy inside the locked part, + # so it just means it'll either get a copy with the new label set, or without it. + # This presumes the implementation doesn't change to make that different, however. + flamegraph._lock = gevent.lock.DummySemaphore() + # There is also a lock we need to bypass on the actual counter values themselves. + # Since they get created dynamically, this means we need to replace the lock function + # that is used to create them. + # This unfortunately means we go without locking for all metrics, not just this one, + # however this is safe because we are using gevent, not threading. The lock is only + # used to make incrementing/decrementing the counter thread-safe, which is not a concern + # under gevent since there are no switch points under the lock. 
+ import prometheus_client.values + prometheus_client.values.Lock = gevent.lock.DummySemaphore + + + def sample(signum, frame): + stack = [] + while frame is not None: + stack.append(frame) + frame = frame.f_back + # format each frame as FUNCTION(MODULE) + stack = ";".join( + "{}({})".format(frame.f_code.co_name, frame.f_globals.get('__name__')) + for frame in stack[::-1] + ) + # increase counter by interval, so final units are in seconds + flamegraph.labels(stack).inc(interval) + # schedule the next signal + signal.setitimer(signal.ITIMER_VIRTUAL, interval) + + def cancel(): + signal.setitimer(signal.ITIMER_VIRTUAL, 0) + atexit.register(cancel) + + signal.signal(signal.SIGVTALRM, sample) + # deliver the first signal in INTERVAL seconds + signal.setitimer(signal.ITIMER_VIRTUAL, interval) diff --git a/common/setup.py b/common/setup.py new file mode 100644 index 0000000..43818fc --- /dev/null +++ b/common/setup.py @@ -0,0 +1,12 @@ +from setuptools import setup, find_packages + +setup( + name = "wubloader-common", + version = "0.0.0", + packages = find_packages(), + install_requires = [ + "gevent==1.5a2", + "monotonic", + "prometheus-client", + ], +)
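# Sketch of how a service built on wubloader-common might wire up the observability helpers
# from common.stats at startup: count log lines per level, optionally sample the stack (only
# active when WUBLOADER_ENABLE_STACKSAMPLER=true), and expose everything to Prometheus. The
# port is a placeholder.
import logging

import prometheus_client as prom

from common.stats import PromLogCountsHandler, install_stacksampler

logging.basicConfig(level='INFO')
PromLogCountsHandler.install()
install_stacksampler()
prom.start_http_server(8001)  # serves the metrics endpoint on port 8001

logging.info("metrics exporter running")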