From 94facb702e05417921ea083e3943b0017e6f8651 Mon Sep 17 00:00:00 2001 From: HeNine <> Date: Mon, 8 Aug 2022 13:52:41 +0200 Subject: [PATCH] api init --- api/.gitignore | 164 ++++++++++++++ api/escher_api/__init__.py | 0 api/escher_api/__main__.py | 16 ++ api/escher_api/escher.py | 427 +++++++++++++++++++++++++++++++++++++ api/escher_api/main.py | 83 +++++++ api/setup.py | 16 ++ 6 files changed, 706 insertions(+) create mode 100644 api/.gitignore create mode 100644 api/escher_api/__init__.py create mode 100644 api/escher_api/__main__.py create mode 100644 api/escher_api/escher.py create mode 100644 api/escher_api/main.py create mode 100644 api/setup.py diff --git a/api/.gitignore b/api/.gitignore new file mode 100644 index 0000000..1e825b5 --- /dev/null +++ b/api/.gitignore @@ -0,0 +1,164 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +.escher_api_env +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# Scratch script directory +/scratch \ No newline at end of file diff --git a/api/escher_api/__init__.py b/api/escher_api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/escher_api/__main__.py b/api/escher_api/__main__.py new file mode 100644 index 0000000..486fc1a --- /dev/null +++ b/api/escher_api/__main__.py @@ -0,0 +1,16 @@ + +import gevent.monkey +gevent.monkey.patch_all() + +import logging +import os + +import argh + +from escher_api.main import main + +LOG_FORMAT = "[%(asctime)s] %(levelname)8s %(name)s(%(module)s:%(lineno)d): %(message)s" + +level = os.environ.get('WUBLOADER_LOG_LEVEL', 'INFO').upper() +logging.basicConfig(level=level, format=LOG_FORMAT) +argh.dispatch_command(main) \ No newline at end of file diff --git a/api/escher_api/escher.py b/api/escher_api/escher.py new file mode 100644 index 0000000..6e29449 --- /dev/null +++ b/api/escher_api/escher.py @@ -0,0 +1,427 @@ + +import flask as flask +from datetime import datetime +from common import database + +app = flask.Flask('escher') + + +class Result: + def __init__(self, transcript=None, vst=None, chat=None): + self.transcript = transcript if transcript else [] + self.vst = vst + self.chat = chat if chat else [] + + def __repr__(self) -> str: + return f'Result(transcript={self.transcript}, vst={self.vst}, chat={self.chat})' + + @property + def start_time(self): + """Compute the start time of the whole result.""" + start_times = [self.vst.start_time] if self.vst else [] + + if self.transcript: + start_times.append(self.transcript[0].start_time) + if self.chat: + start_times.append(self.chat[0].pub_time) + + return min(start_times) + + @property + def end_time(self): + """Compute the start time of the whole result.""" + end_times = [self.vst.end_time] if self.vst else [] + + if self.transcript: + end_times.append(self.transcript[-1].end_time) + if self.chat: + end_times.append(self.chat[-1].pub_time) + + return max(end_times) + + @property + def weight(self): + return sum([cl.rank for cl in self.chat], 0) + \ + 10 * sum([tl.rank for tl in self.transcript], 0) + \ + 20 * (self.vst.rank if self.vst else 0) + + +def get_transcript(db_conn, ts_query, start_time="-infinity", end_time="infinity"): + query = """ + --sql + WITH q AS ( + SELECT convert_query(%(ts_query)s) + ), + relevant_lines AS ( + ( + SELECT id + FROM buscribe_transcriptions + WHERE to_tsvector('english', transcription_line) @@ (SELECT * FROM q) + ) + UNION + ( + SELECT line + FROM buscribe_verified_lines + WHERE to_tsvector('english', verified_line) @@ (SELECT * FROM q) + ) + UNION + ( + SELECT line + FROM buscribe_line_speakers + WHERE to_tsvector('english', speaker_name) @@ (SELECT * FROM q) + ) + UNION + ( + SELECT line + FROM buscribe_line_inferred_speakers + WHERE to_tsvector('english', speaker_name) @@ (SELECT * FROM q) + ) + ) + ( + ( + SELECT id, + start_time, + end_time, + null AS verifier, + names, + transcription_line, + ts_rank_cd( + coalesce( + to_tsvector('english', transcription_line), + ''::tsvector + ) || coalesce( + to_tsvector(array_to_string(names, ' ')), + ''::tsvector + ), + (SELECT * FROM q) + ) AS rank, + ts_headline( + transcription_line, + (SELECT * FROM q), + 'StartSel='''', StopSel=' + ) AS highlighted_text + -- transcription_json + FROM buscribe_transcriptions + LEFT OUTER JOIN ( + SELECT line, + ARRAY( + SELECT speaker_name + FROM buscribe_line_inferred_speakers AS inner_speakers + WHERE inner_speakers.line = buscribe_line_inferred_speakers.line + ) AS names + FROM buscribe_line_inferred_speakers + ) AS inferred_speakers ON id = inferred_speakers.line + WHERE id IN ( + SELECT id + FROM relevant_lines + ) + AND start_time >= coalesce(%(start_time)s, '-infinity'::timestamp) + AND end_time <= coalesce(%(end_time)s, 'infinity'::timestamp) + ) + UNION + ( + SELECT buscribe_transcriptions.id AS id, + start_time, + end_time, + cverifier AS verifier, + names, + coalesce( + verifications.verified_line, + buscribe_transcriptions.transcription_line + ) AS transcription_line, + ts_rank_cd( + coalesce( + setweight(to_tsvector('english', verified_line), 'C'), + to_tsvector( + 'english', + buscribe_transcriptions.transcription_line + ), + ''::tsvector + ) || coalesce( + setweight(to_tsvector(array_to_string(names, ' ')), 'C'), + ''::tsvector + ), + (SELECT * FROM q) + ) AS rank, + ts_headline( + coalesce( + verifications.verified_line, + buscribe_transcriptions.transcription_line + ), + (SELECT * FROM q), + 'StartSel='''', StopSel=' + ) AS highlighted_text + -- null AS transcription_json + FROM buscribe_transcriptions + INNER JOIN ( + SELECT *, + coalesce(relevant_verified.line, relevant_speakers.line) AS cline, + coalesce( + relevant_verified.verifier, + relevant_speakers.verifier + ) AS cverifier + FROM ( + SELECT * + FROM buscribe_verified_lines + WHERE line IN ( + SELECT id + FROM relevant_lines + ) + ) AS relevant_verified + FULL OUTER JOIN ( + SELECT line, + verifier, + ARRAY( + SELECT speaker_name + FROM buscribe_line_speakers AS inner_speakers + WHERE inner_speakers.line = buscribe_line_speakers.line + AND inner_speakers.verifier = buscribe_line_speakers.verifier + ) AS names + FROM buscribe_line_speakers + WHERE line IN ( + SELECT id + FROM relevant_lines + ) + ) AS relevant_speakers ON relevant_verified.line = relevant_speakers.line + AND relevant_speakers.verifier = relevant_verified.verifier + ) AS verifications ON id = verifications.cline + WHERE start_time >= coalesce(%(start_time)s, '-infinity'::timestamp) + AND end_time <= coalesce(%(end_time)s, 'infinity'::timestamp) + ) + ) + ORDER BY + --rank DESC, + start_time; + """ + + db_results = database.query(db_conn, query, + start_time=start_time if start_time is not None else '-infinity', + end_time=end_time if end_time is not None else 'infinity', + ts_query=ts_query) + + # Number of messages + n_m = db_results.rowcount + + # Get duration for message frequency calculation + bus_duration = database.query(db_conn, + """ + --sql + SELECT EXTRACT(EPOCH FROM max(end_time) - min(start_time)) AS bus_duration FROM buscribe_transcriptions; + """ + ).fetchone().bus_duration + + # Priors saying that an interesting event (p < 0.01) is when three messages per minute are posted, giving the + # equivalent average 13 messages per minute + a_p = 13.0 + b_p = 60.0 + + # 0.02th quantile difference (cf. chat) + p = 0.02 + l = b_p + bus_duration + a = a_p + n_m + # Lomax distribution is posterior predictive for exponential + message_duration_diff = l * ((1 - p)**-(1/a) - 1) + current_result = Result() + results = [] + print(message_duration_diff) + for transcript_line in db_results: + + # Current result set is new + if not current_result.transcript: + current_result.transcript.append(transcript_line) + # New message is within window + elif (transcript_line.start_time - current_result.transcript[-1].end_time).total_seconds() <= message_duration_diff: + print((transcript_line.start_time - + current_result.transcript[-1].end_time).total_seconds()) + current_result.transcript.append(transcript_line) + # New message is outside window + else: + # Always save the run (cf. chat; we save all lines) + results.append(current_result) + # Start new run + current_result = Result(transcript=[transcript_line]) + + results.append(current_result) + + return results + + +def get_vst(db_conn, ts_query, start_time="-infinity", end_time="infinity"): + query = """ + --sql + SELECT + video_ranges[1].start AS start_time, + video_ranges[1].end AS end_time, + video_title AS title, + video_id AS id, + ts_rank_cd( + setweight( + to_tsvector('english'::regconfig, video_title), + 'C'), + websearch_to_tsquery(%(ts_query)s) + ) AS rank + FROM events + WHERE + video_ranges[1].start >= %(start_time)s AND + video_ranges[1].end <= %(end_time)s AND + array_length(video_ranges, 1) = 1 AND -- Only handling single-cut videos for now + to_tsvector('english'::regconfig, video_title) @@ websearch_to_tsquery(%(ts_query)s) + ORDER BY + --rank DESC, + start_time + LIMIT 100; -- Hard limit on result number + """ + + results = database.query(db_conn, query, + start_time=start_time if start_time is not None else '-infinity', + end_time=end_time if end_time is not None else 'infinity', + ts_query=ts_query) + + return [Result(vst=result) for result in results] + + +def get_chat(db_conn, ts_query, start_time="-infinity", end_time="infinity"): + query = """ + --sql + SELECT * FROM ( + SELECT + pub_time, + content->'tags'->>'display-name' AS name, + content->'params'->>1 AS content, + ts_rank_cd( + setweight( + to_tsvector('english'::regconfig, content->'params'->>1), + 'A'), + websearch_to_tsquery(%(ts_query)s) + ) AS rank + FROM chat + WHERE + pub_time >= %(start_time)s AND + pub_time <= %(end_time)s AND + to_tsvector('english'::regconfig, content->'params'->>1) @@ websearch_to_tsquery(%(ts_query)s) + --ORDER BY + -- rank DESC + --LIMIT 100 -- Hard limit on result number + ) AS chat ORDER BY pub_time; -- Re-sort by time for merging + """ + + db_results = database.query(db_conn, query, + start_time=start_time if start_time is not None else '-infinity', + end_time=end_time if end_time is not None else 'infinity', + ts_query=ts_query) + + # Number of messages + n_m = db_results.rowcount + + # Get duration for message frequency calculation + bus_duration = database.query(db_conn, + """ + --sql + SELECT EXTRACT(EPOCH FROM max(pub_time) - min(pub_time)) AS bus_duration FROM chat; + """ + ).fetchone().bus_duration + + # Priors saying that an interesting event (p < 0.01) is when three messages per minute are posted, giving the + # equivalent average 13 messages per minute + a_p = 13.0 + b_p = 60.0 + + # 0.01th quantile difference + p = 0.01 + l = b_p + bus_duration + a = a_p + n_m + # Lomax distribution is posterior predictive for exponential + message_duration_diff = l * ((1 - p)**-(1/a) - 1) + + current_result = Result() + results = [] + for chat_line in db_results: + # Current result set is new + if not current_result.chat: + current_result.chat.append(chat_line) + # New message is within window + elif (chat_line.pub_time - current_result.chat[-1].pub_time).total_seconds() <= message_duration_diff: + current_result.chat.append(chat_line) + # New message is outside window + else: + # Current run has more than one message: save it + if len(current_result.chat) > 1: + results.append(current_result) + # Start new run + current_result = Result(chat=[chat_line]) + + if len(current_result.chat) > 1: + results.append(current_result) + + return results + + +def load_result_data(result): + pass + + +def merge_results(transcript: list[Result], vst: list[Result], chat: list[Result]): + """ + Merge different types of results in order of importance. + + First, merge anything that overlaps with a VST clip into the clip's result. + + Second, merge any chat the overlaps with transcripts into the transcript. + + Finally, append remaining chats. + """ + + transcript_i = 0 + chat_i = 0 + + # Merge transcript and chat into vst + for vst_result in vst: + while transcript_i < len(transcript) and transcript[transcript_i].start_time < vst_result.end_time: + if overlap(vst_result, transcript[transcript_i]): + vst_result.transcript.extend( + transcript.pop(transcript_i).transcript) + else: + transcript_i += 1 + + # print(vst_result) + while chat_i < len(chat) and chat[chat_i].start_time < vst_result.end_time: + # print(vst_result) + if overlap(vst_result, chat[chat_i]): + vst_result.chat.extend(chat.pop(chat_i).chat) + # print(vst_result) + else: + chat_i += 1 + + chat_i = 0 + + # Merge chat into transcript + for transcript_result in transcript: + while chat_i < len(chat) and chat[chat_i].start_time < transcript_result.end_time: + if overlap(transcript_result, chat[chat_i]): + transcript_result.chat.extend(chat.pop(chat_i).chat) + else: + chat_i += 1 + + merged = transcript + vst + chat + merged.sort(key=lambda result: result.start_time) + merged.sort(key=lambda result: result.weight, reverse=True) + return merged + + +def overlap(result_a, result_b): + """ + A |---------| + B |---| + + or + + A |-----| + B |---| + + or + + A |---| + B |---| + """ + return result_b.start_time >= result_a.start_time and result_b.start_time <= result_a.end_time or \ + result_b.end_time >= result_a.start_time and result_b.end_time <= result_a.end_time diff --git a/api/escher_api/main.py b/api/escher_api/main.py new file mode 100644 index 0000000..846e134 --- /dev/null +++ b/api/escher_api/main.py @@ -0,0 +1,83 @@ +import logging +import os +from time import sleep + +import argh +import gevent.event +from common import dateutil +from common.database import DBManager +from dateutil.parser import ParserError +from gevent import signal +from gevent.pywsgi import WSGIServer + +from escher_api.escher import app + + +def cors(app): + """WSGI middleware that sets CORS headers""" + HEADERS = [ + ("Access-Control-Allow-Credentials", "false"), + ("Access-Control-Allow-Headers", "*"), + ("Access-Control-Allow-Methods", "GET,POST,HEAD"), + ("Access-Control-Allow-Origin", "*"), + ("Access-Control-Max-Age", "86400"), + ] + def handle(environ, start_response): + def _start_response(status, headers, exc_info=None): + headers += HEADERS + return start_response(status, headers, exc_info) + return app(environ, _start_response) + return handle + + +def servelet(server): + logging.info('Starting WSGI server.') + server.serve_forever() + +@argh.arg('channel', + help="Twitch channel to transcribe.") +@argh.arg('--host', + help='Address or socket server will listen to. Default is 0.0.0.0 (everything on the local machine).') +@argh.arg('--port', + help='Port server will listen on. Default is 8004.') +@argh.arg('--database', + help='Postgres connection string, which is either a space-separated list of key=value pairs, or a URI like: ' + 'postgresql://USER:PASSWORD@HOST/DBNAME?KEY=VALUE') +@argh.arg('--bustime-start', + help='The start time in UTC for the event, for UTC-Bustime conversion') +@argh.arg('--base-dir', + help='Directory from which segments will be grabbed. Default is current working directory.') +def main(channel, database="", host='0.0.0.0', port=8010, bustime_start=None, base_dir=None): + if bustime_start is None: + logging.error("Missing --bustime-start!") + exit(1) + + server = WSGIServer((host, port), cors(app)) + + try: + app.bustime_start = dateutil.parse(bustime_start) + except ParserError: + logging.error("Invalid --bustime-start!") + exit(1) + + app.segments_dir = os.path.join(base_dir, channel, "source") + + app.db_manager = DBManager(dsn=database) + + stopping = gevent.event.Event() + + def stop(): + logging.info("Shutting down") + stopping.set() + + gevent.signal_handler(signal.SIGTERM, stop) + + serve = gevent.spawn(servelet, server) + + # Wait for either the stop signal or the server to oops out. + gevent.wait([serve, stopping], count=1) + + server.stop() + serve.get() # Wait for server to shut down and/or re-raise if serve_forever() errored + + logging.info("Gracefully shut down") \ No newline at end of file diff --git a/api/setup.py b/api/setup.py new file mode 100644 index 0000000..be320d6 --- /dev/null +++ b/api/setup.py @@ -0,0 +1,16 @@ +from setuptools import setup, find_packages + +setup( + name = "escher-api", + version = "0.0.0", + packages = find_packages(), + install_requires = [ + "argh", + "python-dateutil", + "flask", + "gevent", + "monotonic", + "prometheus-client", + "wubloader-common", + ], +)