You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

428 lines
15 KiB
Python

2 years ago
import flask as flask
from datetime import datetime
from common import database
# Flask application object, created with the import name 'escher'.
app = flask.Flask(import_name='escher')
class Result:
    """One search result: a group of transcript lines, a VST row and/or chat messages.

    ``transcript`` and ``chat`` are lists of database rows; ``vst`` is a single
    row or ``None``.  Transcript rows are read for ``start_time``, ``end_time``
    and ``rank``; chat rows for ``pub_time`` and ``rank``; the VST row for
    ``start_time``, ``end_time`` and ``rank``.

    The time-span properties only look at the first/last list entries, so they
    assume both lists are ordered by time (the search queries in this module
    ORDER BY start_time / pub_time).
    """

    def __init__(self, transcript=None, vst=None, chat=None):
        # Always give each instance its own list (never a shared default).
        self.transcript = transcript if transcript else []
        self.vst = vst
        self.chat = chat if chat else []

    def __repr__(self) -> str:
        return f'Result(transcript={self.transcript}, vst={self.vst}, chat={self.chat})'

    @property
    def start_time(self):
        """Compute the start time of the whole result.

        Raises:
            ValueError: if the result holds no transcript, VST or chat data.
        """
        start_times = [self.vst.start_time] if self.vst else []
        if self.transcript:
            start_times.append(self.transcript[0].start_time)
        if self.chat:
            start_times.append(self.chat[0].pub_time)
        if not start_times:
            raise ValueError('Result has no transcript, VST or chat data')
        return min(start_times)

    @property
    def end_time(self):
        """Compute the end time of the whole result.

        Raises:
            ValueError: if the result holds no transcript, VST or chat data.
        """
        end_times = [self.vst.end_time] if self.vst else []
        if self.transcript:
            end_times.append(self.transcript[-1].end_time)
        if self.chat:
            end_times.append(self.chat[-1].pub_time)
        if not end_times:
            raise ValueError('Result has no transcript, VST or chat data')
        return max(end_times)

    @property
    def weight(self):
        """Ranking score: chat ranks count once, transcript ranks 10x, the VST rank 20x."""
        chat_weight = sum(cl.rank for cl in self.chat)
        transcript_weight = sum(tl.rank for tl in self.transcript)
        vst_weight = self.vst.rank if self.vst else 0
        return chat_weight + 10 * transcript_weight + 20 * vst_weight
def get_transcript(db_conn, ts_query, start_time="-infinity", end_time="infinity"):
    """Search transcription lines and group the hits into time-contiguous runs.

    A line matches when ``ts_query`` (converted in SQL by ``convert_query``)
    hits its transcription text, a verified correction, a verified speaker
    name or an inferred speaker name.  Matching lines come back ordered by
    start time.  A line joins the current run when the gap between its start
    and the previous line's end does not exceed a threshold taken from a
    Lomax (posterior predictive) quantile of the inter-line gap.

    Args:
        db_conn: connection handed to ``database.query``.
        ts_query: the user's search string.
        start_time: lower bound on line start time; ``None`` means unbounded.
        end_time: upper bound on line end time; ``None`` means unbounded.

    Returns:
        A list of ``Result`` objects, each with a non-empty ``transcript``.
        The list is empty when nothing matches.
    """
    query = """
        --sql
        WITH q AS (
            SELECT convert_query(%(ts_query)s)
        ),
        relevant_lines AS (
            (
                SELECT id
                FROM buscribe_transcriptions
                WHERE to_tsvector('english', transcription_line) @@ (SELECT * FROM q)
            )
            UNION
            (
                SELECT line
                FROM buscribe_verified_lines
                WHERE to_tsvector('english', verified_line) @@ (SELECT * FROM q)
            )
            UNION
            (
                SELECT line
                FROM buscribe_line_speakers
                WHERE to_tsvector('english', speaker_name) @@ (SELECT * FROM q)
            )
            UNION
            (
                SELECT line
                FROM buscribe_line_inferred_speakers
                WHERE to_tsvector('english', speaker_name) @@ (SELECT * FROM q)
            )
        )
        (
            (
                SELECT id,
                    start_time,
                    end_time,
                    null AS verifier,
                    names,
                    transcription_line,
                    ts_rank_cd(
                        coalesce(
                            to_tsvector('english', transcription_line),
                            ''::tsvector
                        ) || coalesce(
                            to_tsvector(array_to_string(names, ' ')),
                            ''::tsvector
                        ),
                        (SELECT * FROM q)
                    ) AS rank,
                    ts_headline(
                        transcription_line,
                        (SELECT * FROM q),
                        'StartSel=''<span class=\"highlight\">'', StopSel=</span>'
                    ) AS highlighted_text
                    -- transcription_json
                FROM buscribe_transcriptions
                    LEFT OUTER JOIN (
                        SELECT line,
                            ARRAY(
                                SELECT speaker_name
                                FROM buscribe_line_inferred_speakers AS inner_speakers
                                WHERE inner_speakers.line = buscribe_line_inferred_speakers.line
                            ) AS names
                        FROM buscribe_line_inferred_speakers
                    ) AS inferred_speakers ON id = inferred_speakers.line
                WHERE id IN (
                        SELECT id
                        FROM relevant_lines
                    )
                    AND start_time >= coalesce(%(start_time)s, '-infinity'::timestamp)
                    AND end_time <= coalesce(%(end_time)s, 'infinity'::timestamp)
            )
            UNION
            (
                SELECT buscribe_transcriptions.id AS id,
                    start_time,
                    end_time,
                    cverifier AS verifier,
                    names,
                    coalesce(
                        verifications.verified_line,
                        buscribe_transcriptions.transcription_line
                    ) AS transcription_line,
                    ts_rank_cd(
                        coalesce(
                            setweight(to_tsvector('english', verified_line), 'C'),
                            to_tsvector(
                                'english',
                                buscribe_transcriptions.transcription_line
                            ),
                            ''::tsvector
                        ) || coalesce(
                            setweight(to_tsvector(array_to_string(names, ' ')), 'C'),
                            ''::tsvector
                        ),
                        (SELECT * FROM q)
                    ) AS rank,
                    ts_headline(
                        coalesce(
                            verifications.verified_line,
                            buscribe_transcriptions.transcription_line
                        ),
                        (SELECT * FROM q),
                        'StartSel=''<span class=\"highlight\">'', StopSel=</span>'
                    ) AS highlighted_text
                    -- null AS transcription_json
                FROM buscribe_transcriptions
                    INNER JOIN (
                        SELECT *,
                            coalesce(relevant_verified.line, relevant_speakers.line) AS cline,
                            coalesce(
                                relevant_verified.verifier,
                                relevant_speakers.verifier
                            ) AS cverifier
                        FROM (
                                SELECT *
                                FROM buscribe_verified_lines
                                WHERE line IN (
                                        SELECT id
                                        FROM relevant_lines
                                    )
                            ) AS relevant_verified
                            FULL OUTER JOIN (
                                SELECT line,
                                    verifier,
                                    ARRAY(
                                        SELECT speaker_name
                                        FROM buscribe_line_speakers AS inner_speakers
                                        WHERE inner_speakers.line = buscribe_line_speakers.line
                                            AND inner_speakers.verifier = buscribe_line_speakers.verifier
                                    ) AS names
                                FROM buscribe_line_speakers
                                WHERE line IN (
                                        SELECT id
                                        FROM relevant_lines
                                    )
                            ) AS relevant_speakers ON relevant_verified.line = relevant_speakers.line
                            AND relevant_speakers.verifier = relevant_verified.verifier
                    ) AS verifications ON id = verifications.cline
                WHERE start_time >= coalesce(%(start_time)s, '-infinity'::timestamp)
                    AND end_time <= coalesce(%(end_time)s, 'infinity'::timestamp)
            )
        )
        ORDER BY
            --rank DESC,
            start_time;
    """
    db_results = database.query(db_conn, query,
                                start_time=start_time if start_time is not None else '-infinity',
                                end_time=end_time if end_time is not None else 'infinity',
                                ts_query=ts_query)

    # Number of matching lines
    n_m = db_results.rowcount

    # Get duration for line frequency calculation
    bus_duration = database.query(db_conn,
                                  """
                                  --sql
                                  SELECT EXTRACT(EPOCH FROM max(end_time) - min(start_time)) AS bus_duration FROM buscribe_transcriptions;
                                  """
                                  ).fetchone().bus_duration
    # max()/min() give NULL on an empty table, and PostgreSQL 14+ returns
    # EXTRACT as numeric, which drivers typically map to Decimal.  Normalise
    # to float so the float arithmetic below cannot raise TypeError.
    bus_duration = float(bus_duration) if bus_duration is not None else 0.0

    # Priors saying that an interesting event (p < 0.01) is when three messages per minute are posted, giving the
    # equivalent average 13 messages per minute
    a_p = 13.0
    b_p = 60.0
    # 0.02th quantile difference (cf. chat)
    p = 0.02

    l = b_p + bus_duration
    a = a_p + n_m
    # Lomax distribution is posterior predictive for exponential
    message_duration_diff = l * ((1 - p)**-(1/a) - 1)

    results = []
    current_result = Result()
    for transcript_line in db_results:
        # Current result set is new
        if not current_result.transcript:
            current_result.transcript.append(transcript_line)
        # New line is within window (gap measured from the previous line's end)
        elif (transcript_line.start_time - current_result.transcript[-1].end_time).total_seconds() <= message_duration_diff:
            current_result.transcript.append(transcript_line)
        # New line is outside window
        else:
            # Always save the run (cf. chat; we save all lines)
            results.append(current_result)
            # Start new run
            current_result = Result(transcript=[transcript_line])

    # Only keep the final run if it holds anything: an empty Result has no
    # start/end time and would break merging and sorting downstream.
    if current_result.transcript:
        results.append(current_result)

    return results
def get_vst(db_conn, ts_query, start_time="-infinity", end_time="infinity"):
    """Search the ``events`` table for rows whose ``video_title`` matches ``ts_query``.

    Only events with exactly one entry in ``video_ranges`` are considered.
    That single range must lie within [start_time, end_time].  At most 100
    rows are returned, ordered by start time.

    Args:
        db_conn: connection handed to ``database.query``.
        ts_query: web-search style query string (``websearch_to_tsquery``).
        start_time: lower bound on the range start; ``None`` means unbounded.
        end_time: upper bound on the range end; ``None`` means unbounded.

    Returns:
        One ``Result(vst=row)`` per matching event.  Each row carries
        start_time, end_time, title, id and rank.
    """
    # The query is parsed with the same 'english' configuration used to build
    # the title tsvector.  Relying on default_text_search_config could stem
    # the search terms differently and silently miss matches.
    query = """
        --sql
        SELECT
            video_ranges[1].start AS start_time,
            video_ranges[1].end AS end_time,
            video_title AS title,
            video_id AS id,
            ts_rank_cd(
                setweight(
                    to_tsvector('english'::regconfig, video_title),
                    'C'),
                websearch_to_tsquery('english'::regconfig, %(ts_query)s)
            ) AS rank
        FROM events
        WHERE
            video_ranges[1].start >= %(start_time)s AND
            video_ranges[1].end <= %(end_time)s AND
            array_length(video_ranges, 1) = 1 AND -- Only handling single-cut videos for now
            to_tsvector('english'::regconfig, video_title) @@ websearch_to_tsquery('english'::regconfig, %(ts_query)s)
        ORDER BY
            --rank DESC,
            start_time
        LIMIT 100; -- Hard limit on result number
    """
    results = database.query(db_conn, query,
                             start_time=start_time if start_time is not None else '-infinity',
                             end_time=end_time if end_time is not None else 'infinity',
                             ts_query=ts_query)
    return [Result(vst=result) for result in results]
def get_chat(db_conn, ts_query, start_time="-infinity", end_time="infinity"):
    """Search chat messages and group the hits into bursts.

    Messages whose text (``content->'params'->>1``) matches ``ts_query`` are
    fetched in ``pub_time`` order.  A message joins the current burst when it
    was published within a threshold of the previous one.  The threshold is
    a Lomax (posterior predictive) quantile of the inter-message gap.

    Args:
        db_conn: connection handed to ``database.query``.
        ts_query: web-search style query string (``websearch_to_tsquery``).
        start_time: lower bound on ``pub_time``; ``None`` means unbounded.
        end_time: upper bound on ``pub_time``; ``None`` means unbounded.

    Returns:
        A list of ``Result`` objects, each holding at least two chat rows.
        Isolated single messages are dropped.
    """
    # The query is parsed with the same 'english' configuration used to build
    # the message tsvector, so stemming on both sides agrees.
    query = """
        --sql
        SELECT * FROM (
            SELECT
                pub_time,
                content->'tags'->>'display-name' AS name,
                content->'params'->>1 AS content,
                ts_rank_cd(
                    setweight(
                        to_tsvector('english'::regconfig, content->'params'->>1),
                        'A'),
                    websearch_to_tsquery('english'::regconfig, %(ts_query)s)
                ) AS rank
            FROM chat
            WHERE
                pub_time >= %(start_time)s AND
                pub_time <= %(end_time)s AND
                to_tsvector('english'::regconfig, content->'params'->>1) @@ websearch_to_tsquery('english'::regconfig, %(ts_query)s)
            --ORDER BY
            -- rank DESC
            --LIMIT 100 -- Hard limit on result number
        ) AS chat ORDER BY pub_time; -- Re-sort by time for merging
    """
    db_results = database.query(db_conn, query,
                                start_time=start_time if start_time is not None else '-infinity',
                                end_time=end_time if end_time is not None else 'infinity',
                                ts_query=ts_query)

    # Number of messages
    n_m = db_results.rowcount

    # Get duration for message frequency calculation
    bus_duration = database.query(db_conn,
                                  """
                                  --sql
                                  SELECT EXTRACT(EPOCH FROM max(pub_time) - min(pub_time)) AS bus_duration FROM chat;
                                  """
                                  ).fetchone().bus_duration
    # max()/min() give NULL on an empty table, and PostgreSQL 14+ returns
    # EXTRACT as numeric, which drivers typically map to Decimal.  Normalise
    # to float so the float arithmetic below cannot raise TypeError.
    bus_duration = float(bus_duration) if bus_duration is not None else 0.0

    # Priors saying that an interesting event (p < 0.01) is when three messages per minute are posted, giving the
    # equivalent average 13 messages per minute
    a_p = 13.0
    b_p = 60.0
    # 0.01th quantile difference
    p = 0.01

    l = b_p + bus_duration
    a = a_p + n_m
    # Lomax distribution is posterior predictive for exponential
    message_duration_diff = l * ((1 - p)**-(1/a) - 1)

    current_result = Result()
    results = []
    for chat_line in db_results:
        # Current result set is new
        if not current_result.chat:
            current_result.chat.append(chat_line)
        # New message is within window
        elif (chat_line.pub_time - current_result.chat[-1].pub_time).total_seconds() <= message_duration_diff:
            current_result.chat.append(chat_line)
        # New message is outside window
        else:
            # Current run has more than one message: save it
            if len(current_result.chat) > 1:
                results.append(current_result)
            # Start new run
            current_result = Result(chat=[chat_line])

    # Same rule for the final run: single messages are not a burst.
    if len(current_result.chat) > 1:
        results.append(current_result)

    return results
def load_result_data(result):
    """Placeholder for loading extra data for *result*; not implemented yet, always returns None."""
def merge_results(transcript: list[Result], vst: list[Result], chat: list[Result]):
    """
    Merge different types of results in order of importance.

    First, merge anything that overlaps with a VST clip into the clip's result.
    Second, merge any chat that overlaps with transcripts into the transcript.
    Finally, append remaining chats.

    Each input list is scanned with a forward-only cursor, so the lists are
    expected to be ordered by start time.  Merged-away entries are popped from
    the ``transcript`` and ``chat`` lists passed in, so the caller's lists are
    modified in place.  The ``Result`` objects in ``vst`` and ``transcript``
    are extended in place.

    Returns:
        All remaining results, ordered by descending weight and, within equal
        weight, by ascending start time.
    """
    transcript_i = 0
    chat_i = 0

    # Merge transcript and chat into vst.  Note: Result.end_time includes the
    # entries already merged in, so the scan window can grow while absorbing.
    for vst_result in vst:
        while transcript_i < len(transcript) and transcript[transcript_i].start_time < vst_result.end_time:
            if overlap(vst_result, transcript[transcript_i]):
                vst_result.transcript.extend(
                    transcript.pop(transcript_i).transcript)
            else:
                transcript_i += 1
        while chat_i < len(chat) and chat[chat_i].start_time < vst_result.end_time:
            if overlap(vst_result, chat[chat_i]):
                vst_result.chat.extend(chat.pop(chat_i).chat)
            else:
                chat_i += 1

    # Merge chat into transcript (restart the chat cursor for this pass)
    chat_i = 0
    for transcript_result in transcript:
        while chat_i < len(chat) and chat[chat_i].start_time < transcript_result.end_time:
            if overlap(transcript_result, chat[chat_i]):
                transcript_result.chat.extend(chat.pop(chat_i).chat)
            else:
                chat_i += 1

    merged = transcript + vst + chat
    # Two stable sorts: the second (by weight) keeps start-time order as the
    # tie-breaker established by the first.
    merged.sort(key=lambda result: result.start_time)
    merged.sort(key=lambda result: result.weight, reverse=True)
    return merged
def overlap(result_a, result_b):
    """Return True if the time spans of the two results intersect (closed intervals).

    Covers every way two spans can touch or intersect:

        A |---------|       A |-----|         A   |---|       A   |---|
        B   |---|           B    |---|        B |---|         B |-------|

    The last case (B fully enclosing A) was previously missed, so a long
    transcript/chat run spanning a short clip was never merged into it.
    Spans that merely touch at an endpoint count as overlapping.

    Both arguments only need ``start_time`` and ``end_time`` attributes that
    compare with each other, with ``start_time <= end_time``.
    """
    return (result_b.start_time <= result_a.end_time
            and result_b.end_time >= result_a.start_time)