You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
500 lines
18 KiB
Python
500 lines
18 KiB
Python
|
|
from collections import namedtuple
|
|
from unicodedata import name
|
|
import flask as flask
|
|
from datetime import datetime
|
|
from common import database
|
|
|
|
from gevent.pool import Pool
|
|
|
|
from psycopg2.extras import execute_values
|
|
|
|
# Flask application object for this module, created with the import name 'escher'.
app = flask.Flask('escher')
|
|
|
|
|
|
class Result:
    """One search hit, optionally combining transcript lines, a VST entry and chat lines.

    The class relies on duck typing only:

    * ``start_time`` / ``end_time`` read ``start_time`` and ``end_time`` from the
      VST entry and from the first / last transcript line, and ``pub_time`` from
      the first / last chat line (both lists are expected to be in time order).
    * ``weight`` reads ``rank`` from the VST entry and from every transcript and
      chat line.
    """

    def __init__(self, transcript=None, vst=None, chat=None):
        # A falsy argument (None or an empty list) is replaced by a fresh list so
        # that instances never share a default list with each other.
        self.transcript = transcript if transcript else []
        self.vst = vst
        self.chat = chat if chat else []

    def __repr__(self) -> str:
        return f'Result(transcript={self.transcript}, vst={self.vst}, chat={self.chat})'

    @property
    def start_time(self):
        """Compute the start time of the whole result.

        This is the earliest of the VST start, the first transcript line's start
        and the first chat line's publication time, whichever are present.

        Raises:
            ValueError: if the result holds no VST, transcript or chat at all.
        """
        start_times = [self.vst.start_time] if self.vst else []

        if self.transcript:
            start_times.append(self.transcript[0].start_time)
        if self.chat:
            start_times.append(self.chat[0].pub_time)

        if not start_times:
            # Same exception type min() would raise, but with a usable message.
            raise ValueError("empty Result has no start time")
        return min(start_times)

    @property
    def end_time(self):
        """Compute the end time of the whole result.

        This is the latest of the VST end, the last transcript line's end and the
        last chat line's publication time, whichever are present.

        Raises:
            ValueError: if the result holds no VST, transcript or chat at all.
        """
        end_times = [self.vst.end_time] if self.vst else []

        if self.transcript:
            end_times.append(self.transcript[-1].end_time)
        if self.chat:
            end_times.append(self.chat[-1].pub_time)

        if not end_times:
            # Same exception type max() would raise, but with a usable message.
            raise ValueError("empty Result has no end time")
        return max(end_times)

    @property
    def weight(self):
        """Combined rank of the result: chat ranks x1, transcript ranks x10, VST rank x20."""
        chat_weight = sum(chat_line.rank for chat_line in self.chat)
        transcript_weight = sum(transcript_line.rank for transcript_line in self.transcript)
        vst_weight = self.vst.rank if self.vst else 0
        return chat_weight + 10 * transcript_weight + 20 * vst_weight
|
|
|
|
|
|
def get_transcript(db_conn, ts_query, start_time="-infinity", end_time="infinity"):
    """Search transcription lines and group the hits into runs of nearby lines.

    The SQL selects every transcription line whose machine transcription,
    verified text, verified speaker names or inferred speaker names match
    ``ts_query`` (converted by the database function ``convert_query``), limited
    to lines lying within ``start_time``..``end_time``, ordered by start time.

    Consecutive hits are put into the same ``Result`` when the gap between the
    end of the previous line and the start of the next one is at most a
    threshold derived from the overall hit frequency (see comments below).

    Args:
        db_conn: database connection handed to ``database.query``.
        ts_query: the search query string.
        start_time: lower time bound; ``None`` means unbounded.
        end_time: upper time bound; ``None`` means unbounded.

    Returns:
        A list of ``Result`` objects with ``transcript`` filled, in time order.
        The list is empty when nothing matched.
    """
    query = """
        --sql
        WITH q AS (
            SELECT convert_query(%(ts_query)s)
        ),
        relevant_lines AS (
            (
                SELECT id
                FROM buscribe_transcriptions
                WHERE to_tsvector('english', transcription_line) @@ (SELECT * FROM q)
            )
            UNION
            (
                SELECT line
                FROM buscribe_verified_lines
                WHERE to_tsvector('english', verified_line) @@ (SELECT * FROM q)
            )
            UNION
            (
                SELECT line
                FROM buscribe_line_speakers
                WHERE to_tsvector('english', speaker_name) @@ (SELECT * FROM q)
            )
            UNION
            (
                SELECT line
                FROM buscribe_line_inferred_speakers
                WHERE to_tsvector('english', speaker_name) @@ (SELECT * FROM q)
            )
        )
        (
            (
                SELECT id,
                    start_time,
                    end_time,
                    null AS verifier,
                    names,
                    transcription_line,
                    ts_rank_cd(
                        coalesce(
                            to_tsvector('english', transcription_line),
                            ''::tsvector
                        ) || coalesce(
                            to_tsvector(array_to_string(names, ' ')),
                            ''::tsvector
                        ),
                        (SELECT * FROM q)
                    ) AS rank,
                    ts_headline(
                        transcription_line,
                        (SELECT * FROM q),
                        'StartSel=''<span class=\"highlight\">'', StopSel=</span>'
                    ) AS highlighted_text
                    -- transcription_json
                FROM buscribe_transcriptions
                    LEFT OUTER JOIN (
                        SELECT line,
                            ARRAY(
                                SELECT speaker_name
                                FROM buscribe_line_inferred_speakers AS inner_speakers
                                WHERE inner_speakers.line = buscribe_line_inferred_speakers.line
                            ) AS names
                        FROM buscribe_line_inferred_speakers
                    ) AS inferred_speakers ON id = inferred_speakers.line
                WHERE id IN (
                        SELECT id
                        FROM relevant_lines
                    )
                    AND start_time >= coalesce(%(start_time)s, '-infinity'::timestamp)
                    AND end_time <= coalesce(%(end_time)s, 'infinity'::timestamp)
            )
            UNION
            (
                SELECT buscribe_transcriptions.id AS id,
                    start_time,
                    end_time,
                    cverifier AS verifier,
                    names,
                    coalesce(
                        verifications.verified_line,
                        buscribe_transcriptions.transcription_line
                    ) AS transcription_line,
                    ts_rank_cd(
                        coalesce(
                            setweight(to_tsvector('english', verified_line), 'C'),
                            to_tsvector(
                                'english',
                                buscribe_transcriptions.transcription_line
                            ),
                            ''::tsvector
                        ) || coalesce(
                            setweight(to_tsvector(array_to_string(names, ' ')), 'C'),
                            ''::tsvector
                        ),
                        (SELECT * FROM q)
                    ) AS rank,
                    ts_headline(
                        coalesce(
                            verifications.verified_line,
                            buscribe_transcriptions.transcription_line
                        ),
                        (SELECT * FROM q),
                        'StartSel=''<span class=\"highlight\">'', StopSel=</span>'
                    ) AS highlighted_text
                    -- null AS transcription_json
                FROM buscribe_transcriptions
                    INNER JOIN (
                        SELECT *,
                            coalesce(relevant_verified.line, relevant_speakers.line) AS cline,
                            coalesce(
                                relevant_verified.verifier,
                                relevant_speakers.verifier
                            ) AS cverifier
                        FROM (
                                SELECT *
                                FROM buscribe_verified_lines
                                WHERE line IN (
                                        SELECT id
                                        FROM relevant_lines
                                    )
                            ) AS relevant_verified
                            FULL OUTER JOIN (
                                SELECT line,
                                    verifier,
                                    ARRAY(
                                        SELECT speaker_name
                                        FROM buscribe_line_speakers AS inner_speakers
                                        WHERE inner_speakers.line = buscribe_line_speakers.line
                                            AND inner_speakers.verifier = buscribe_line_speakers.verifier
                                    ) AS names
                                FROM buscribe_line_speakers
                                WHERE line IN (
                                        SELECT id
                                        FROM relevant_lines
                                    )
                            ) AS relevant_speakers ON relevant_verified.line = relevant_speakers.line
                            AND relevant_speakers.verifier = relevant_verified.verifier
                    ) AS verifications ON id = verifications.cline
                WHERE start_time >= coalesce(%(start_time)s, '-infinity'::timestamp)
                    AND end_time <= coalesce(%(end_time)s, 'infinity'::timestamp)
            )
        )
        ORDER BY
            --rank DESC,
            start_time;
        """

    db_results = database.query(db_conn, query,
                                start_time=start_time if start_time is not None else '-infinity',
                                end_time=end_time if end_time is not None else 'infinity',
                                ts_query=ts_query)

    # Number of messages
    n_m = db_results.rowcount

    # Get duration for message frequency calculation
    bus_duration = database.query(db_conn,
                                  """
                                  --sql
                                  SELECT EXTRACT(EPOCH FROM max(end_time) - min(start_time)) AS bus_duration FROM buscribe_transcriptions;
                                  """
                                  ).fetchone().bus_duration
    # The aggregate is NULL on an empty table, and EXTRACT may come back as a
    # Decimal (numeric) rather than a float; neither can be added to a float.
    bus_duration = float(bus_duration or 0)

    # Priors saying that an interesting event (p < 0.01) is when three messages per minute are posted, giving the
    # equivalent average 13 messages per minute
    a_p = 13.0
    b_p = 60.0

    # 0.02th quantile difference (cf. chat)
    p = 0.02
    scale = b_p + bus_duration
    shape = a_p + n_m

    # Lomax distribution is posterior predictive for exponential; this is its
    # p-quantile, used as the largest gap (in seconds) that keeps two lines in one run.
    message_duration_diff = scale * ((1 - p)**-(1/shape) - 1)

    current_result = Result()
    results = []
    for transcript_line in db_results:
        # Current result set is new
        if not current_result.transcript:
            current_result.transcript.append(transcript_line)
        # New message is within window
        elif (transcript_line.start_time - current_result.transcript[-1].end_time).total_seconds() <= message_duration_diff:
            current_result.transcript.append(transcript_line)
        # New message is outside window
        else:
            # Always save the run (cf. chat; we save all lines)
            results.append(current_result)
            # Start new run
            current_result = Result(transcript=[transcript_line])

    # Save the trailing run, but never an empty placeholder: an empty Result has
    # no start/end time and would break any code that sorts or merges results.
    if current_result.transcript:
        results.append(current_result)

    return results
|
|
|
|
|
|
def get_vst(db_conn, ts_query, start_time="-infinity", end_time="infinity"):
    """Search event video titles and wrap every hit in its own ``Result``.

    Only events with exactly one video range are considered. The first video
    range must lie within ``start_time``..``end_time`` (``None`` means
    unbounded) and the title must match ``ts_query`` interpreted by
    ``websearch_to_tsquery``. At most 100 rows are returned, ordered by start
    time.
    """
    # Translate a missing bound into the matching infinity literal.
    lower_bound = '-infinity' if start_time is None else start_time
    upper_bound = 'infinity' if end_time is None else end_time

    sql = """
        --sql
        SELECT
            video_ranges[1].start AS start_time,
            video_ranges[1].end AS end_time,
            video_title AS title,
            video_id AS id,
            ts_rank_cd(
                setweight(
                    to_tsvector('english'::regconfig, video_title),
                    'C'),
                websearch_to_tsquery(%(ts_query)s)
            ) AS rank
        FROM events
        WHERE
            video_ranges[1].start >= %(start_time)s AND
            video_ranges[1].end <= %(end_time)s AND
            array_length(video_ranges, 1) = 1 AND -- Only handling single-cut videos for now
            to_tsvector('english'::regconfig, video_title) @@ websearch_to_tsquery(%(ts_query)s)
        ORDER BY
            --rank DESC,
            start_time
        LIMIT 100; -- Hard limit on result number
        """

    rows = database.query(
        db_conn,
        sql,
        start_time=lower_bound,
        end_time=upper_bound,
        ts_query=ts_query,
    )

    # One Result per matching video; nothing is grouped here.
    hits = []
    for row in rows:
        hits.append(Result(vst=row))
    return hits
|
|
|
|
|
|
def get_chat(db_conn, ts_query, start_time="-infinity", end_time="infinity"):
    """Search chat messages and group the hits into bursts of nearby messages.

    The SQL selects chat rows published within ``start_time``..``end_time``
    whose message text (``content->'params'->>1``) matches ``ts_query``
    interpreted by ``websearch_to_tsquery``, ordered by publication time.

    Consecutive hits are put into the same ``Result`` when they are at most a
    threshold apart; the threshold is derived from the overall hit frequency
    (see comments below). Runs consisting of a single message are discarded.

    Args:
        db_conn: database connection handed to ``database.query``.
        ts_query: the search query string.
        start_time: lower time bound; ``None`` means unbounded.
        end_time: upper time bound; ``None`` means unbounded.

    Returns:
        A list of ``Result`` objects with ``chat`` filled (two or more messages
        each), in time order.
    """
    query = """
        --sql
        SELECT * FROM (
            SELECT
                pub_time,
                content->'tags'->>'display-name' AS name,
                content->'params'->>1 AS content,
                ts_rank_cd(
                    setweight(
                        to_tsvector('english'::regconfig, content->'params'->>1),
                        'A'),
                    websearch_to_tsquery(%(ts_query)s)
                ) AS rank
            FROM chat
            WHERE
                pub_time >= %(start_time)s AND
                pub_time <= %(end_time)s AND
                to_tsvector('english'::regconfig, content->'params'->>1) @@ websearch_to_tsquery(%(ts_query)s)
            --ORDER BY
            --    rank DESC
            --LIMIT 100 -- Hard limit on result number
        ) AS chat ORDER BY pub_time; -- Re-sort by time for merging
        """

    db_results = database.query(db_conn, query,
                                start_time=start_time if start_time is not None else '-infinity',
                                end_time=end_time if end_time is not None else 'infinity',
                                ts_query=ts_query)

    # Number of messages
    n_m = db_results.rowcount

    # Get duration for message frequency calculation
    bus_duration = database.query(db_conn,
                                  """
                                  --sql
                                  SELECT EXTRACT(EPOCH FROM max(pub_time) - min(pub_time)) AS bus_duration FROM chat;
                                  """
                                  ).fetchone().bus_duration
    # The aggregate is NULL on an empty table, and EXTRACT may come back as a
    # Decimal (numeric) rather than a float; neither can be added to a float.
    bus_duration = float(bus_duration or 0)

    # Priors saying that an interesting event (p < 0.01) is when three messages per minute are posted, giving the
    # equivalent average 13 messages per minute
    a_p = 13.0
    b_p = 60.0

    # 0.01th quantile difference
    p = 0.01
    scale = b_p + bus_duration
    shape = a_p + n_m

    # Lomax distribution is posterior predictive for exponential; this is its
    # p-quantile, used as the largest gap (in seconds) that keeps two messages in one run.
    message_duration_diff = scale * ((1 - p)**-(1/shape) - 1)

    current_result = Result()
    results = []
    for chat_line in db_results:
        # Current result set is new
        if not current_result.chat:
            current_result.chat.append(chat_line)
        # New message is within window
        elif (chat_line.pub_time - current_result.chat[-1].pub_time).total_seconds() <= message_duration_diff:
            current_result.chat.append(chat_line)
        # New message is outside window
        else:
            # Current run has more than one message: save it
            if len(current_result.chat) > 1:
                results.append(current_result)
            # Start new run
            current_result = Result(chat=[chat_line])

    # Same rule for the trailing run: isolated messages are not reported.
    if len(current_result.chat) > 1:
        results.append(current_result)

    return results
|
|
|
|
def load_results_data(db_conn, results):
    """
    Replace chat and transcript with all entries in result's timeframe.

    The start and end time of every result are captured first; then each
    result's ``chat`` and ``transcript`` lists are emptied and refilled with
    every chat row published inside that timespan and every transcription line
    lying entirely inside it. Transcription lines are delivered by the database
    as JSON objects and converted to ``TranscriptRecord`` named tuples with the
    fields ``start_time``, ``end_time``, ``names`` and ``transcription_line``.

    Args:
        db_conn: database connection; its cursors must yield rows with
            attribute access (``row.id``), as the original code already assumed.
        results: list of ``Result`` objects; modified in place.

    Returns:
        The same ``results`` list.
    """
    # Nothing to look up: skip the database round-trips entirely.
    if not results:
        return results

    # Must be computed before the lists are cleared: the timespans depend on them.
    result_timespans = [(i, result.start_time, result.end_time) for (i, result) in enumerate(results)]

    # Clear lists so we can later insert new lines
    for result in results:
        result.chat = []
        result.transcript = []

    # Field names mirror the columns of the row_to_json() sub-select below, so the
    # record type no longer depends on the first row happening to contain a line.
    TranscriptRecord = namedtuple("TranscriptRecord",
                                  ["start_time", "end_time", "names", "transcription_line"])

    # execute_values() splits its arguments into pages of `page_size` and runs one
    # statement per page; only the last statement's rows stay on the cursor. Send
    # everything in a single page so no result is silently left without data.
    single_page = len(result_timespans)

    with db_conn.cursor() as cur:
        execute_values(cur,
                       """
                       --sql
                       WITH timespans (id, start_time, end_time) AS (VALUES %s)
                       SELECT
                           timespans.id,
                           pub_time,
                           content->'tags'->>'display-name' AS name,
                           content->'params'->>1 AS content FROM timespans JOIN chat ON (pub_time BETWEEN start_time AND end_time);
                       """,
                       result_timespans,
                       page_size=single_page
                       )

        # timespans.id is the index of the result the row belongs to.
        for chat_line in cur:
            results[chat_line.id].chat.append(chat_line)

        execute_values(cur,
                       """
                       --sql
                       WITH timespans (id, start_time, end_time) AS (VALUES %s)
                       SELECT timespans.id,
                           ARRAY(
                               SELECT row_to_json(transcriptions)
                               FROM (
                                       SELECT buscribe_transcriptions.start_time,
                                           buscribe_transcriptions.end_time,
                                           ARRAY(
                                               SELECT speaker_name
                                               FROM buscribe_line_inferred_speakers
                                               WHERE buscribe_line_inferred_speakers.line = buscribe_transcriptions.id
                                           ) as names,
                                           buscribe_transcriptions.transcription_line
                                       FROM buscribe_transcriptions
                                       WHERE buscribe_transcriptions.start_time >= timespans.start_time
                                           AND buscribe_transcriptions.start_time <= timespans.end_time
                                           AND buscribe_transcriptions.end_time >= timespans.start_time
                                           AND buscribe_transcriptions.end_time <= timespans.end_time
                                   ) AS transcriptions
                           ) AS transcriptions
                       FROM timespans;
                       """,
                       result_timespans,
                       page_size=single_page
                       )

        # One row per timespan; `transcriptions` may be an empty array.
        for row in cur:
            results[row.id].transcript.extend(
                TranscriptRecord(**transcription) for transcription in row.transcriptions)

    return results
|
|
|
|
|
|
def merge_results(transcript: list[Result], vst: list[Result], chat: list[Result], limit: int, offset=0):
    """
    Combine the three kinds of results into one ranked page.

    1. Every transcript or chat result overlapping a VST result is folded into
       that VST result.
    2. Every remaining chat result overlapping a transcript result is folded
       into that transcript result.
    3. Whatever is left is concatenated, ordered by descending weight (ties in
       ascending start time) and sliced to ``limit`` entries from ``offset``.

    Folded results are removed from the ``transcript`` / ``chat`` lists that
    were passed in, and the absorbing results are modified in place.
    """

    def absorb(host, pool, cursor, field):
        """Fold pool entries overlapping `host` into host.<field>; return the advanced cursor."""
        while cursor < len(pool) and pool[cursor].start_time < host.end_time:
            if overlap(host, pool[cursor]):
                # pop() keeps the cursor pointing at the next candidate
                getattr(host, field).extend(getattr(pool.pop(cursor), field))
            else:
                cursor += 1
        return cursor

    # Merge transcript and chat into vst; cursors carry over from clip to clip
    transcript_cursor = 0
    chat_cursor = 0
    for clip in vst:
        transcript_cursor = absorb(clip, transcript, transcript_cursor, "transcript")
        chat_cursor = absorb(clip, chat, chat_cursor, "chat")

    # Merge chat into transcript, scanning the remaining chat from the start again
    chat_cursor = 0
    for run in transcript:
        chat_cursor = absorb(run, chat, chat_cursor, "chat")

    # Stable two-pass sort: time order first, then heaviest first
    ranked = sorted(transcript + vst + chat, key=lambda entry: entry.start_time)
    ranked.sort(key=lambda entry: entry.weight, reverse=True)

    # Slicing already clamps the upper bound to the list length
    return ranked[offset:offset + limit]
|
|
|
|
|
|
def overlap(result_a, result_b):
    """
    Return True when the time spans of the two results share at least one instant.

    Both arguments need ``start_time`` and ``end_time``; the spans are treated
    as closed intervals, so spans that merely touch count as overlapping.

    A |---------|
    B    |---|

    or

    A |-----|
    B     |---|

    or

    A     |---|
    B |---|

    or

    A    |---|
    B |---------|
    """
    # Two closed intervals intersect exactly when each one starts no later than
    # the other one ends. This also covers B fully containing A, which a check
    # of B's endpoints against A alone would miss.
    return result_a.start_time <= result_b.end_time and result_b.start_time <= result_a.end_time
|