Branch: trunk
Author: HeNine (2 years ago)
Parent: 60012b90e2
Commit: 94facb702e

api/.gitignore (vendored)

@@ -0,0 +1,164 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
.escher_api_env
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Scratch script directory
/scratch

@@ -0,0 +1,16 @@ api/escher_api/__main__.py
import gevent.monkey
gevent.monkey.patch_all()

import logging
import os

import argh

from escher_api.main import main

LOG_FORMAT = "[%(asctime)s] %(levelname)8s %(name)s(%(module)s:%(lineno)d): %(message)s"

level = os.environ.get('WUBLOADER_LOG_LEVEL', 'INFO').upper()
logging.basicConfig(level=level, format=LOG_FORMAT)

argh.dispatch_command(main)

@@ -0,0 +1,427 @@ api/escher_api/escher.py
import logging

import flask
from datetime import datetime

from common import database

app = flask.Flask('escher')


class Result:
    def __init__(self, transcript=None, vst=None, chat=None):
        self.transcript = transcript if transcript else []
        self.vst = vst
        self.chat = chat if chat else []

    def __repr__(self) -> str:
        return f'Result(transcript={self.transcript}, vst={self.vst}, chat={self.chat})'

    @property
    def start_time(self):
        """Compute the start time of the whole result."""
        start_times = [self.vst.start_time] if self.vst else []
        if self.transcript:
            start_times.append(self.transcript[0].start_time)
        if self.chat:
            start_times.append(self.chat[0].pub_time)
        return min(start_times)

    @property
    def end_time(self):
        """Compute the end time of the whole result."""
        end_times = [self.vst.end_time] if self.vst else []
        if self.transcript:
            end_times.append(self.transcript[-1].end_time)
        if self.chat:
            end_times.append(self.chat[-1].pub_time)
        return max(end_times)

    @property
    def weight(self):
        # Relevance weighting: chat ranks count 1x, transcript ranks 10x,
        # and a matched video (VST) rank 20x.
        return sum([cl.rank for cl in self.chat], 0) + \
            10 * sum([tl.rank for tl in self.transcript], 0) + \
            20 * (self.vst.rank if self.vst else 0)


def get_transcript(db_conn, ts_query, start_time="-infinity", end_time="infinity"):
    query = """
    --sql
    WITH q AS (
        SELECT convert_query(%(ts_query)s)
    ),
    relevant_lines AS (
        (
            SELECT id
            FROM buscribe_transcriptions
            WHERE to_tsvector('english', transcription_line) @@ (SELECT * FROM q)
        )
        UNION
        (
            SELECT line
            FROM buscribe_verified_lines
            WHERE to_tsvector('english', verified_line) @@ (SELECT * FROM q)
        )
        UNION
        (
            SELECT line
            FROM buscribe_line_speakers
            WHERE to_tsvector('english', speaker_name) @@ (SELECT * FROM q)
        )
        UNION
        (
            SELECT line
            FROM buscribe_line_inferred_speakers
            WHERE to_tsvector('english', speaker_name) @@ (SELECT * FROM q)
        )
    )
    (
        (
            SELECT id,
                start_time,
                end_time,
                null AS verifier,
                names,
                transcription_line,
                ts_rank_cd(
                    coalesce(
                        to_tsvector('english', transcription_line),
                        ''::tsvector
                    ) || coalesce(
                        to_tsvector(array_to_string(names, ' ')),
                        ''::tsvector
                    ),
                    (SELECT * FROM q)
                ) AS rank,
                ts_headline(
                    transcription_line,
                    (SELECT * FROM q),
                    'StartSel=''<span class=\"highlight\">'', StopSel=</span>'
                ) AS highlighted_text
                -- transcription_json
            FROM buscribe_transcriptions
                LEFT OUTER JOIN (
                    SELECT line,
                        ARRAY(
                            SELECT speaker_name
                            FROM buscribe_line_inferred_speakers AS inner_speakers
                            WHERE inner_speakers.line = buscribe_line_inferred_speakers.line
                        ) AS names
                    FROM buscribe_line_inferred_speakers
                ) AS inferred_speakers ON id = inferred_speakers.line
            WHERE id IN (
                    SELECT id
                    FROM relevant_lines
                )
                AND start_time >= coalesce(%(start_time)s, '-infinity'::timestamp)
                AND end_time <= coalesce(%(end_time)s, 'infinity'::timestamp)
        )
        UNION
        (
            SELECT buscribe_transcriptions.id AS id,
                start_time,
                end_time,
                cverifier AS verifier,
                names,
                coalesce(
                    verifications.verified_line,
                    buscribe_transcriptions.transcription_line
                ) AS transcription_line,
                ts_rank_cd(
                    coalesce(
                        setweight(to_tsvector('english', verified_line), 'C'),
                        to_tsvector(
                            'english',
                            buscribe_transcriptions.transcription_line
                        ),
                        ''::tsvector
                    ) || coalesce(
                        setweight(to_tsvector(array_to_string(names, ' ')), 'C'),
                        ''::tsvector
                    ),
                    (SELECT * FROM q)
                ) AS rank,
                ts_headline(
                    coalesce(
                        verifications.verified_line,
                        buscribe_transcriptions.transcription_line
                    ),
                    (SELECT * FROM q),
                    'StartSel=''<span class=\"highlight\">'', StopSel=</span>'
                ) AS highlighted_text
                -- null AS transcription_json
            FROM buscribe_transcriptions
                INNER JOIN (
                    SELECT *,
                        coalesce(relevant_verified.line, relevant_speakers.line) AS cline,
                        coalesce(
                            relevant_verified.verifier,
                            relevant_speakers.verifier
                        ) AS cverifier
                    FROM (
                            SELECT *
                            FROM buscribe_verified_lines
                            WHERE line IN (
                                    SELECT id
                                    FROM relevant_lines
                                )
                        ) AS relevant_verified
                        FULL OUTER JOIN (
                            SELECT line,
                                verifier,
                                ARRAY(
                                    SELECT speaker_name
                                    FROM buscribe_line_speakers AS inner_speakers
                                    WHERE inner_speakers.line = buscribe_line_speakers.line
                                        AND inner_speakers.verifier = buscribe_line_speakers.verifier
                                ) AS names
                            FROM buscribe_line_speakers
                            WHERE line IN (
                                    SELECT id
                                    FROM relevant_lines
                                )
                        ) AS relevant_speakers ON relevant_verified.line = relevant_speakers.line
                        AND relevant_speakers.verifier = relevant_verified.verifier
                ) AS verifications ON id = verifications.cline
            WHERE start_time >= coalesce(%(start_time)s, '-infinity'::timestamp)
                AND end_time <= coalesce(%(end_time)s, 'infinity'::timestamp)
        )
    )
    ORDER BY
        --rank DESC,
        start_time;
    """

    db_results = database.query(db_conn, query,
                                start_time=start_time if start_time is not None else '-infinity',
                                end_time=end_time if end_time is not None else 'infinity',
                                ts_query=ts_query)
    # Number of messages
    n_m = db_results.rowcount

    # Get duration for message frequency calculation
    bus_duration = database.query(db_conn, """
        --sql
        SELECT EXTRACT(EPOCH FROM max(end_time) - min(start_time)) AS bus_duration FROM buscribe_transcriptions;
        """).fetchone().bus_duration

    # Priors say that an interesting event (p < 0.01) is one where three messages
    # per minute are posted, giving an equivalent average of 13 messages per minute.
    a_p = 13.0
    b_p = 60.0

    # 0.02 quantile of the gap distribution (cf. chat, which uses 0.01)
    p = 0.02

    l = b_p + bus_duration
    a = a_p + n_m

    # The Lomax distribution is the posterior predictive for exponential gaps.
    message_duration_diff = l * ((1 - p)**-(1/a) - 1)
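    # A sketch of the math (assuming exponential message gaps with a Gamma(a_p, b_p)
    # prior on the rate): after observing n_m messages over bus_duration seconds, the
    # posterior predictive for the next gap is Lomax(shape a, scale l), with CDF
    # F(x) = 1 - (1 + x/l)**-a. Solving F(x) = p gives x = l * ((1 - p)**(-1/a) - 1),
    # the threshold above; gaps at or below it group messages into one run.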

    current_result = Result()
    results = []

    logging.debug("Transcript grouping window: %fs", message_duration_diff)

    for transcript_line in db_results:
        # Current result set is new
        if not current_result.transcript:
            current_result.transcript.append(transcript_line)
        # New message is within window
        elif (transcript_line.start_time - current_result.transcript[-1].end_time).total_seconds() <= message_duration_diff:
            current_result.transcript.append(transcript_line)
        # New message is outside window
        else:
            # Always save the run (cf. chat, where single-line runs are dropped)
            results.append(current_result)
            # Start new run
            current_result = Result(transcript=[transcript_line])

    # Save the final run, unless the query returned no rows at all
    if current_result.transcript:
        results.append(current_result)

    return results


def get_vst(db_conn, ts_query, start_time="-infinity", end_time="infinity"):
    query = """
    --sql
    SELECT
        (video_ranges[1]).start AS start_time,
        (video_ranges[1])."end" AS end_time,
        video_title AS title,
        video_id AS id,
        ts_rank_cd(
            setweight(
                to_tsvector('english'::regconfig, video_title),
                'C'),
            websearch_to_tsquery(%(ts_query)s)
        ) AS rank
    FROM events
    WHERE
        (video_ranges[1]).start >= %(start_time)s AND
        (video_ranges[1])."end" <= %(end_time)s AND
        array_length(video_ranges, 1) = 1 AND -- Only handling single-cut videos for now
        to_tsvector('english'::regconfig, video_title) @@ websearch_to_tsquery(%(ts_query)s)
    ORDER BY
        --rank DESC,
        start_time
    LIMIT 100; -- Hard limit on result number
    """

    results = database.query(db_conn, query,
                             start_time=start_time if start_time is not None else '-infinity',
                             end_time=end_time if end_time is not None else 'infinity',
                             ts_query=ts_query)

    return [Result(vst=result) for result in results]


def get_chat(db_conn, ts_query, start_time="-infinity", end_time="infinity"):
    query = """
    --sql
    SELECT * FROM (
        SELECT
            pub_time,
            content->'tags'->>'display-name' AS name,
            content->'params'->>1 AS content,
            ts_rank_cd(
                setweight(
                    to_tsvector('english'::regconfig, content->'params'->>1),
                    'A'),
                websearch_to_tsquery(%(ts_query)s)
            ) AS rank
        FROM chat
        WHERE
            pub_time >= %(start_time)s AND
            pub_time <= %(end_time)s AND
            to_tsvector('english'::regconfig, content->'params'->>1) @@ websearch_to_tsquery(%(ts_query)s)
        --ORDER BY
        --    rank DESC
        --LIMIT 100 -- Hard limit on result number
    ) AS chat ORDER BY pub_time; -- Re-sort by time for merging
    """

    db_results = database.query(db_conn, query,
                                start_time=start_time if start_time is not None else '-infinity',
                                end_time=end_time if end_time is not None else 'infinity',
                                ts_query=ts_query)

    # Number of messages
    n_m = db_results.rowcount

    # Get duration for message frequency calculation
    bus_duration = database.query(db_conn, """
        --sql
        SELECT EXTRACT(EPOCH FROM max(pub_time) - min(pub_time)) AS bus_duration FROM chat;
        """).fetchone().bus_duration

    # Priors say that an interesting event (p < 0.01) is one where three messages
    # per minute are posted, giving an equivalent average of 13 messages per minute.
    a_p = 13.0
    b_p = 60.0

    # 0.01 quantile of the gap distribution
    p = 0.01

    l = b_p + bus_duration
    a = a_p + n_m

    # The Lomax distribution is the posterior predictive for exponential gaps.
    message_duration_diff = l * ((1 - p)**-(1/a) - 1)
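    # Same Lomax-quantile threshold as in get_transcript (see the derivation sketch
    # there), with p = 0.01 rather than 0.02.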

    current_result = Result()
    results = []

    for chat_line in db_results:
        # Current result set is new
        if not current_result.chat:
            current_result.chat.append(chat_line)
        # New message is within window
        elif (chat_line.pub_time - current_result.chat[-1].pub_time).total_seconds() <= message_duration_diff:
            current_result.chat.append(chat_line)
        # New message is outside window
        else:
            # Current run has more than one message: save it
            if len(current_result.chat) > 1:
                results.append(current_result)
            # Start new run
            current_result = Result(chat=[chat_line])

    # Save the final run if it, too, has more than one message
    if len(current_result.chat) > 1:
        results.append(current_result)

    return results


def load_result_data(result):
    pass


def merge_results(transcript: list[Result], vst: list[Result], chat: list[Result]):
    """
    Merge the different types of results in order of importance.

    First, merge anything that overlaps with a VST clip into the clip's result.
    Second, merge any chat that overlaps with a transcript run into that run's result.
    Finally, append the remaining chat runs as results of their own.
    """
    transcript_i = 0
    chat_i = 0

    # Merge transcript and chat into vst
    for vst_result in vst:
        while transcript_i < len(transcript) and transcript[transcript_i].start_time < vst_result.end_time:
            if overlap(vst_result, transcript[transcript_i]):
                vst_result.transcript.extend(
                    transcript.pop(transcript_i).transcript)
            else:
                transcript_i += 1

        while chat_i < len(chat) and chat[chat_i].start_time < vst_result.end_time:
            if overlap(vst_result, chat[chat_i]):
                vst_result.chat.extend(chat.pop(chat_i).chat)
            else:
                chat_i += 1

    chat_i = 0

    # Merge chat into transcript
    for transcript_result in transcript:
        while chat_i < len(chat) and chat[chat_i].start_time < transcript_result.end_time:
            if overlap(transcript_result, chat[chat_i]):
                transcript_result.chat.extend(chat.pop(chat_i).chat)
            else:
                chat_i += 1

    merged = transcript + vst + chat

    # Two stable sorts: order by weight (descending), breaking ties by start time.
    merged.sort(key=lambda result: result.start_time)
    merged.sort(key=lambda result: result.weight, reverse=True)

    return merged
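
# Illustration (hypothetical numbers): a VST result spanning 10:00-10:05 absorbs a
# transcript run starting at 10:02 and any chat run starting inside that window; a
# chat run starting at 10:07 instead merges into a later transcript run it overlaps,
# and whatever chat remains is kept as standalone results.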


def overlap(result_a, result_b):
    """
    Check whether result B starts or ends inside result A:

    A |---------|
    B    |---|
    or
    A |-----|
    B    |-----|
    or
    A    |-----|
    B |-----|
    """
    # Note: the case where B strictly contains A on both sides is not detected.
    return (result_b.start_time >= result_a.start_time and result_b.start_time <= result_a.end_time) or \
        (result_b.end_time >= result_a.start_time and result_b.end_time <= result_a.end_time)

@@ -0,0 +1,83 @@ api/escher_api/main.py
import logging
import os

import argh
import gevent.event
from common import dateutil
from common.database import DBManager
from dateutil.parser import ParserError
from gevent import signal
from gevent.pywsgi import WSGIServer

from escher_api.escher import app


def cors(app):
    """WSGI middleware that sets CORS headers"""
    HEADERS = [
        ("Access-Control-Allow-Credentials", "false"),
        ("Access-Control-Allow-Headers", "*"),
        ("Access-Control-Allow-Methods", "GET,POST,HEAD"),
        ("Access-Control-Allow-Origin", "*"),
        ("Access-Control-Max-Age", "86400"),
    ]

    def handle(environ, start_response):
        def _start_response(status, headers, exc_info=None):
            headers += HEADERS
            return start_response(status, headers, exc_info)
        return app(environ, _start_response)

    return handle
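
# Note: these headers are appended to every response, including errors; "*" origins
# together with Allow-Credentials "false" is the fully public, credential-less setup.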


def servelet(server):
    logging.info('Starting WSGI server.')
    server.serve_forever()


@argh.arg('channel',
          help="Twitch channel to transcribe.")
@argh.arg('--host',
          help='Address or socket server will listen to. Default is 0.0.0.0 (everything on the local machine).')
@argh.arg('--port',
          help='Port server will listen on. Default is 8010.')
@argh.arg('--database',
          help='Postgres connection string, which is either a space-separated list of key=value pairs, or a URI like: '
               'postgresql://USER:PASSWORD@HOST/DBNAME?KEY=VALUE')
@argh.arg('--bustime-start',
          help='The start time in UTC for the event, for UTC-Bustime conversion')
@argh.arg('--base-dir',
          help='Directory from which segments will be grabbed. Default is current working directory.')
def main(channel, database="", host='0.0.0.0', port=8010, bustime_start=None, base_dir='.'):
    if bustime_start is None:
        logging.error("Missing --bustime-start!")
        exit(1)

    server = WSGIServer((host, port), cors(app))

    try:
        app.bustime_start = dateutil.parse(bustime_start)
    except ParserError:
        logging.error("Invalid --bustime-start!")
        exit(1)

    app.segments_dir = os.path.join(base_dir, channel, "source")
    app.db_manager = DBManager(dsn=database)

    stopping = gevent.event.Event()

    def stop():
        logging.info("Shutting down")
        stopping.set()

    gevent.signal_handler(signal.SIGTERM, stop)

    serve = gevent.spawn(servelet, server)

    # Wait for either the stop signal or the server to oops out.
    gevent.wait([serve, stopping], count=1)

    server.stop()
    serve.get()  # Wait for server to shut down and/or re-raise if serve_forever() errored

    logging.info("Gracefully shut down")

@@ -0,0 +1,16 @@ api/setup.py
from setuptools import setup, find_packages

setup(
    name="escher-api",
    version="0.0.0",
    packages=find_packages(),
    install_requires=[
        "argh",
        "python-dateutil",
        "flask",
        "gevent",
        "monotonic",
        "prometheus-client",
        "wubloader-common",
    ],
)