You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

500 lines
18 KiB
Python

from collections import namedtuple
from unicodedata import name
import flask as flask
from datetime import datetime
from common import database
from gevent.pool import Pool
from psycopg2.extras import execute_values
# Flask application object for this module, registered under the import name 'escher'.
app = flask.Flask(import_name='escher')
class Result:
    """
    One search hit shown to the user: a group of related lines plus an optional video.

    A Result may combine any of:
      * ``transcript`` -- list of transcript rows with ``start_time``/``end_time``
        (and ``rank`` when they come from a search query),
      * ``vst`` -- a single video row with ``start_time``/``end_time``/``rank``,
      * ``chat`` -- list of chat rows with ``pub_time`` (and ``rank`` for search hits).

    ``transcript`` and ``chat`` are expected to be in chronological order: the
    time span is read from their first and last elements.
    """

    def __init__(self, transcript=None, vst=None, chat=None):
        # A falsy argument (None or an empty list) gives each instance its own fresh list.
        self.transcript = transcript if transcript else []
        self.vst = vst
        self.chat = chat if chat else []

    def __repr__(self) -> str:
        return f'Result(transcript={self.transcript}, vst={self.vst}, chat={self.chat})'

    @property
    def start_time(self):
        """
        Compute the start time of the whole result.

        This is the earliest of the video start, the first transcript line's
        start and the first chat message's publication time.

        Raises:
            ValueError: if the result has no video, transcript or chat.
        """
        start_times = [self.vst.start_time] if self.vst else []
        if self.transcript:
            start_times.append(self.transcript[0].start_time)
        if self.chat:
            start_times.append(self.chat[0].pub_time)
        if not start_times:
            raise ValueError('empty Result has no start time')
        return min(start_times)

    @property
    def end_time(self):
        """
        Compute the end time of the whole result.

        This is the latest of the video end, the last transcript line's end
        and the last chat message's publication time.

        Raises:
            ValueError: if the result has no video, transcript or chat.
        """
        end_times = [self.vst.end_time] if self.vst else []
        if self.transcript:
            end_times.append(self.transcript[-1].end_time)
        if self.chat:
            end_times.append(self.chat[-1].pub_time)
        if not end_times:
            raise ValueError('empty Result has no end time')
        return max(end_times)

    @property
    def weight(self):
        """
        Ranking weight: chat ranks count once, transcript ranks 10x, the video rank 20x.

        Every line must carry a ``rank`` attribute. The search queries provide
        one; rows reloaded by ``load_results_data`` do not.
        """
        chat_score = sum(line.rank for line in self.chat)
        transcript_score = sum(line.rank for line in self.transcript)
        vst_score = self.vst.rank if self.vst else 0
        return chat_score + 10 * transcript_score + 20 * vst_score
def get_transcript(db_conn, ts_query, start_time="-infinity", end_time="infinity"):
    """
    Search the transcript for ``ts_query`` and group matching lines into runs.

    A line matches when the query hits its machine transcription, a verified
    transcription, or a (verified or inferred) speaker name. Matching lines are
    walked in start-time order. A line joins the current run while the gap
    since the previous line's end is at most a threshold derived from the
    number of matches (the Lomax quantile below).

    Args:
        db_conn: database connection, passed through to ``database.query``.
        ts_query: user search text, converted by the SQL ``convert_query`` function.
        start_time, end_time: optional bounds on line start/end times;
            ``None`` means unbounded.

    Returns:
        A list of ``Result`` objects in chronological order, each holding one
        run of matching transcript rows. Empty when nothing matched.
    """
    query = """
--sql
WITH q AS (
SELECT convert_query(%(ts_query)s)
),
relevant_lines AS (
(
SELECT id
FROM buscribe_transcriptions
WHERE to_tsvector('english', transcription_line) @@ (SELECT * FROM q)
)
UNION
(
SELECT line
FROM buscribe_verified_lines
WHERE to_tsvector('english', verified_line) @@ (SELECT * FROM q)
)
UNION
(
SELECT line
FROM buscribe_line_speakers
WHERE to_tsvector('english', speaker_name) @@ (SELECT * FROM q)
)
UNION
(
SELECT line
FROM buscribe_line_inferred_speakers
WHERE to_tsvector('english', speaker_name) @@ (SELECT * FROM q)
)
)
(
(
SELECT id,
start_time,
end_time,
null AS verifier,
names,
transcription_line,
ts_rank_cd(
coalesce(
to_tsvector('english', transcription_line),
''::tsvector
) || coalesce(
to_tsvector(array_to_string(names, ' ')),
''::tsvector
),
(SELECT * FROM q)
) AS rank,
ts_headline(
transcription_line,
(SELECT * FROM q),
'StartSel=''<span class=\"highlight\">'', StopSel=</span>'
) AS highlighted_text
-- transcription_json
FROM buscribe_transcriptions
LEFT OUTER JOIN (
SELECT line,
ARRAY(
SELECT speaker_name
FROM buscribe_line_inferred_speakers AS inner_speakers
WHERE inner_speakers.line = buscribe_line_inferred_speakers.line
) AS names
FROM buscribe_line_inferred_speakers
) AS inferred_speakers ON id = inferred_speakers.line
WHERE id IN (
SELECT id
FROM relevant_lines
)
AND start_time >= coalesce(%(start_time)s, '-infinity'::timestamp)
AND end_time <= coalesce(%(end_time)s, 'infinity'::timestamp)
)
UNION
(
SELECT buscribe_transcriptions.id AS id,
start_time,
end_time,
cverifier AS verifier,
names,
coalesce(
verifications.verified_line,
buscribe_transcriptions.transcription_line
) AS transcription_line,
ts_rank_cd(
coalesce(
setweight(to_tsvector('english', verified_line), 'C'),
to_tsvector(
'english',
buscribe_transcriptions.transcription_line
),
''::tsvector
) || coalesce(
setweight(to_tsvector(array_to_string(names, ' ')), 'C'),
''::tsvector
),
(SELECT * FROM q)
) AS rank,
ts_headline(
coalesce(
verifications.verified_line,
buscribe_transcriptions.transcription_line
),
(SELECT * FROM q),
'StartSel=''<span class=\"highlight\">'', StopSel=</span>'
) AS highlighted_text
-- null AS transcription_json
FROM buscribe_transcriptions
INNER JOIN (
SELECT *,
coalesce(relevant_verified.line, relevant_speakers.line) AS cline,
coalesce(
relevant_verified.verifier,
relevant_speakers.verifier
) AS cverifier
FROM (
SELECT *
FROM buscribe_verified_lines
WHERE line IN (
SELECT id
FROM relevant_lines
)
) AS relevant_verified
FULL OUTER JOIN (
SELECT line,
verifier,
ARRAY(
SELECT speaker_name
FROM buscribe_line_speakers AS inner_speakers
WHERE inner_speakers.line = buscribe_line_speakers.line
AND inner_speakers.verifier = buscribe_line_speakers.verifier
) AS names
FROM buscribe_line_speakers
WHERE line IN (
SELECT id
FROM relevant_lines
)
) AS relevant_speakers ON relevant_verified.line = relevant_speakers.line
AND relevant_speakers.verifier = relevant_verified.verifier
) AS verifications ON id = verifications.cline
WHERE start_time >= coalesce(%(start_time)s, '-infinity'::timestamp)
AND end_time <= coalesce(%(end_time)s, 'infinity'::timestamp)
)
)
ORDER BY
--rank DESC,
start_time;
"""
    db_results = database.query(db_conn, query,
        start_time=start_time if start_time is not None else '-infinity',
        end_time=end_time if end_time is not None else 'infinity',
        ts_query=ts_query)
    # Number of matching lines
    n_m = db_results.rowcount
    # Total span of the transcript, used to estimate the line rate
    bus_duration = database.query(db_conn,
        """
--sql
SELECT EXTRACT(EPOCH FROM max(end_time) - min(start_time)) AS bus_duration FROM buscribe_transcriptions;
"""
    ).fetchone().bus_duration
    # The aggregate is NULL on an empty table. EXTRACT returns numeric (a Python
    # Decimal) on PostgreSQL 14+, which cannot be added to a float. Normalise to
    # float seconds.
    bus_duration = float(bus_duration) if bus_duration is not None else 0.0
    # Priors saying that an interesting event (p < 0.01) is when three messages per minute are posted, giving the
    # equivalent average 13 messages per minute
    a_p = 13.0
    b_p = 60.0
    # 0.02th quantile difference (cf. chat)
    p = 0.02
    scale = b_p + bus_duration
    shape = a_p + n_m
    # Lomax distribution is posterior predictive for exponential. This is its
    # p-quantile: the largest gap (in seconds) still treated as the same run.
    message_duration_diff = scale * ((1 - p) ** -(1 / shape) - 1)

    results = []
    current_result = Result()
    for transcript_line in db_results:
        if not current_result.transcript:
            # First line of the first run
            current_result.transcript.append(transcript_line)
        elif (transcript_line.start_time - current_result.transcript[-1].end_time).total_seconds() <= message_duration_diff:
            # New line is within the window: extend the current run
            current_result.transcript.append(transcript_line)
        else:
            # Gap too large: always save the run (cf. chat; we save all lines)
            # and start a new one
            results.append(current_result)
            current_result = Result(transcript=[transcript_line])
    # Only keep the final run if something matched. An empty Result has no
    # start/end time and would crash later merging and loading.
    if current_result.transcript:
        results.append(current_result)
    return results
def get_vst(db_conn, ts_query, start_time="-infinity", end_time="infinity"):
    """
    Find videos whose titles match ``ts_query``.

    Only events with exactly one video range are considered. That range must
    lie within [start_time, end_time]; ``None`` for either bound means
    unbounded. At most 100 matches are returned, earliest first.

    Returns:
        A list of ``Result`` objects, one per matching video, each with only
        ``vst`` set to the row (start_time, end_time, title, id, rank).
    """
    sql = """
--sql
SELECT
video_ranges[1].start AS start_time,
video_ranges[1].end AS end_time,
video_title AS title,
video_id AS id,
ts_rank_cd(
setweight(
to_tsvector('english'::regconfig, video_title),
'C'),
websearch_to_tsquery(%(ts_query)s)
) AS rank
FROM events
WHERE
video_ranges[1].start >= %(start_time)s AND
video_ranges[1].end <= %(end_time)s AND
array_length(video_ranges, 1) = 1 AND -- Only handling single-cut videos for now
to_tsvector('english'::regconfig, video_title) @@ websearch_to_tsquery(%(ts_query)s)
ORDER BY
--rank DESC,
start_time
LIMIT 100; -- Hard limit on result number
"""
    # Substitute open-ended bounds for missing ones.
    lower_bound = '-infinity' if start_time is None else start_time
    upper_bound = 'infinity' if end_time is None else end_time
    rows = database.query(db_conn, sql, start_time=lower_bound, end_time=upper_bound, ts_query=ts_query)
    return [Result(vst=row) for row in rows]
def get_chat(db_conn, ts_query, start_time="-infinity", end_time="infinity"):
    """
    Search chat messages for ``ts_query`` and group matches into bursts.

    Matching messages (full-text search over the message text) are walked in
    publication order. A message joins the current burst while its gap from
    the previous message is at most a threshold derived from the number of
    matches (the Lomax quantile below).

    Args:
        db_conn: database connection, passed through to ``database.query``.
        ts_query: user search text, parsed with ``websearch_to_tsquery``.
        start_time, end_time: optional bounds on ``pub_time``; ``None`` means
            unbounded.

    Returns:
        A list of ``Result`` objects in chronological order, each holding a
        burst of at least two matching chat rows. Isolated single messages
        are dropped.
    """
    query = """
--sql
SELECT * FROM (
SELECT
pub_time,
content->'tags'->>'display-name' AS name,
content->'params'->>1 AS content,
ts_rank_cd(
setweight(
to_tsvector('english'::regconfig, content->'params'->>1),
'A'),
websearch_to_tsquery(%(ts_query)s)
) AS rank
FROM chat
WHERE
pub_time >= %(start_time)s AND
pub_time <= %(end_time)s AND
to_tsvector('english'::regconfig, content->'params'->>1) @@ websearch_to_tsquery(%(ts_query)s)
--ORDER BY
-- rank DESC
--LIMIT 100 -- Hard limit on result number
) AS chat ORDER BY pub_time; -- Re-sort by time for merging
"""
    db_results = database.query(db_conn, query,
        start_time=start_time if start_time is not None else '-infinity',
        end_time=end_time if end_time is not None else 'infinity',
        ts_query=ts_query)
    # Number of messages
    n_m = db_results.rowcount
    # Total span of all chat, used to estimate the message rate
    bus_duration = database.query(db_conn,
        """
--sql
SELECT EXTRACT(EPOCH FROM max(pub_time) - min(pub_time)) AS bus_duration FROM chat;
"""
    ).fetchone().bus_duration
    # The aggregate is NULL when the chat table is empty. EXTRACT returns numeric
    # (a Python Decimal) on PostgreSQL 14+, which cannot be added to a float.
    # Normalise to float seconds.
    bus_duration = float(bus_duration) if bus_duration is not None else 0.0
    # Priors saying that an interesting event (p < 0.01) is when three messages per minute are posted, giving the
    # equivalent average 13 messages per minute
    a_p = 13.0
    b_p = 60.0
    # 0.01th quantile difference
    p = 0.01
    scale = b_p + bus_duration
    shape = a_p + n_m
    # Lomax distribution is posterior predictive for exponential. This is its
    # p-quantile: the largest gap (in seconds) still treated as the same burst.
    message_duration_diff = scale * ((1 - p) ** -(1 / shape) - 1)

    results = []
    current_result = Result()
    for chat_line in db_results:
        if not current_result.chat:
            # First message of the first burst
            current_result.chat.append(chat_line)
        elif (chat_line.pub_time - current_result.chat[-1].pub_time).total_seconds() <= message_duration_diff:
            # New message is within the window: extend the current burst
            current_result.chat.append(chat_line)
        else:
            # Gap too large: keep the burst only if it has more than one message
            if len(current_result.chat) > 1:
                results.append(current_result)
            current_result = Result(chat=[chat_line])
    if len(current_result.chat) > 1:
        results.append(current_result)
    return results
def load_results_data(db_conn, results):
"""
Replace chat and transcript with all entries in result's timeframe.
"""
# ggroup = Pool(size=30)
# results = ggroup.map(lambda result: load_result_data(db_manager, result), results)
result_timespans = [(i, result.start_time, result.end_time) for (i, result) in enumerate(results)]
# Clear lists so we can later insert new lines
for result in results:
result.chat = []
result.transcript = []
cur = db_conn.cursor()
execute_values(cur,
"""
--sql
WITH timespans (id, start_time, end_time) AS (VALUES %s)
SELECT
timespans.id,
pub_time,
content->'tags'->>'display-name' AS name,
content->'params'->>1 AS content FROM timespans JOIN chat ON (pub_time BETWEEN start_time AND end_time);
""",
result_timespans
)
for chat_line in cur:
results[chat_line.id].chat.append(chat_line)
execute_values(cur,
"""
--sql
WITH timespans (id, start_time, end_time) AS (VALUES %s)
SELECT timespans.id,
ARRAY(
SELECT row_to_json(transcriptions)
FROM (
SELECT buscribe_transcriptions.start_time,
buscribe_transcriptions.end_time,
ARRAY(
SELECT speaker_name
FROM buscribe_line_inferred_speakers
WHERE buscribe_line_inferred_speakers.line = buscribe_transcriptions.id
) as names,
buscribe_transcriptions.transcription_line
FROM buscribe_transcriptions
WHERE buscribe_transcriptions.start_time >= timespans.start_time
AND buscribe_transcriptions.start_time <= timespans.end_time
AND buscribe_transcriptions.end_time >= timespans.start_time
AND buscribe_transcriptions.end_time <= timespans.end_time
) AS transcriptions
) AS transcriptions
FROM timespans;
""",
result_timespans
)
TranscriptRecord = namedtuple("TranscriptRecord", cur.fetchone().transcriptions[0].keys())
cur.scroll(-1)
for result in cur:
results[result.id].transcript.extend([TranscriptRecord(**transcription) for transcription in result.transcriptions])
return results
def _absorb_overlapping(targets, candidates, field):
    """
    Merge each candidate that overlaps a target into it; return the candidates left over.

    Both lists are assumed to be sorted by start time. A single cursor walks
    ``candidates``. For each target, every candidate starting before the
    target's end is examined once: it is either absorbed (its ``field`` list
    is appended to the target's) or kept as a leftover. Candidates never
    reached are kept too. The target's span is re-read on every step,
    because absorbing can extend it.
    """
    leftovers = []
    i = 0
    for target in targets:
        while i < len(candidates) and candidates[i].start_time < target.end_time:
            candidate = candidates[i]
            i += 1
            if overlap(target, candidate):
                getattr(target, field).extend(getattr(candidate, field))
            else:
                leftovers.append(candidate)
    leftovers.extend(candidates[i:])
    return leftovers


def merge_results(transcript: list[Result], vst: list[Result], chat: list[Result], limit: int, offset = 0):
    """
    Merge different types of results in order of importance and return one page.

    First, merge anything that overlaps with a VST clip into the clip's result.
    Second, merge any chat that overlaps with transcripts into the transcript.
    Finally, append remaining chats.

    The input lists are assumed to be sorted by start time, as the search
    functions produce them. The caller's lists are not modified. Results that
    absorb others are extended in place.

    Returns:
        ``merged[offset:offset + limit]``, where ``merged`` is ordered by
        descending weight with ties broken by start time.
    """
    remaining_transcript = _absorb_overlapping(vst, transcript, 'transcript')
    remaining_chat = _absorb_overlapping(vst, chat, 'chat')
    remaining_chat = _absorb_overlapping(remaining_transcript, remaining_chat, 'chat')
    merged = remaining_transcript + vst + remaining_chat
    # Two stable sorts: weight is the primary key, start time the tie-breaker.
    merged.sort(key=lambda result: result.start_time)
    merged.sort(key=lambda result: result.weight, reverse=True)
    # Slicing already clamps to the list length.
    return merged[offset:offset + limit]
def overlap(result_a, result_b):
    """
    Return True if the time spans of ``result_a`` and ``result_b`` intersect.

    Endpoints are inclusive, so spans that only touch count as overlapping.
    All arrangements are covered: B inside A, B straddling either end of A,
    and B enclosing A entirely. For example:

        A      |-----|        A  |---------|        A      |---|
        B   |-----|           B     |---|           B  |-----------|

    Both arguments only need ``start_time``/``end_time`` attributes with
    start <= end.
    """
    return result_b.start_time <= result_a.end_time and result_b.end_time >= result_a.start_time