bus_analyzer: Add a testing routine to check analyzer changes

pull/361/head
Mike Lang 1 year ago
parent a4eceea958
commit 5e43446c92

@ -157,7 +157,7 @@ def recognize_odometer(prototypes, frame):
odo = extract_odo(frame) odo = extract_odo(frame)
digits = extract_digits(odo, include_last=False) digits = extract_digits(odo, include_last=False)
digits = [recognize_digit(prototypes, digit) for digit in digits] digits = [recognize_digit(prototypes, digit) for digit in digits]
value = sum(digit * 10**i for i, (digit, _, _) in enumerate(digits[::-1])) value = sum(digit * 10.**i for i, (digit, _, _) in enumerate(digits[::-1]))
# Use average score of digits as frame score # Use average score of digits as frame score
score = sum(score for _, score, _ in digits) / len(digits) score = sum(score for _, score, _ in digits) / len(digits)
return value, score, digits return value, score, digits

@ -2,6 +2,7 @@
import datetime import datetime
import logging import logging
import os import os
import random
import signal import signal
import traceback import traceback
@ -28,6 +29,66 @@ def do_extract_segment(*segment_paths, prototypes_path="./odo-digit-prototypes")
print(f"{segment_path} {odometer}") print(f"{segment_path} {odometer}")
@cli
def compare_segments(dbconnect, base_dir='.', prototypes_path="./odo-digit-prototypes", since=None, until=None, num=100, null_chance=0.25, verbose=False):
"""
Collect some representitive samples from the database and re-runs them to compare to previous results.
num is how many samples to try.
"""
prototypes = load_prototypes(prototypes_path)
dbmanager = database.DBManager(dsn=dbconnect)
conn = dbmanager.get_conn()
where = []
if since:
where.append("timestamp >= %(since)s")
if until:
where.append("timestamp < %(until)s")
if not where:
where = ["true"]
where = " AND ".join(where)
result = database.query(conn, f"""
SELECT odometer, segment
FROM bus_data
WHERE segment IS NOT NULL
AND {where}
""", since=since, until=until)
# To get a wider range of tests, pick at random from all unique odo readings
available = {}
for row in result.fetchall():
available.setdefault(row.odometer, []).append(row.segment)
selected = []
while available and len(selected) < num:
if None in available and random.random() < null_chance:
odometer = None
else:
odometer = random.choice(list(available.keys()))
segments = available[odometer]
random.shuffle(segments)
selected.append((odometer, segments.pop()))
if not segments:
del available[odometer]
results = []
for old_odometer, segment in selected:
path = os.path.join(base_dir, segment)
segment_info = parse_segment_path(path)
odometer = extract_segment(prototypes, segment_info)
results.append((segment, old_odometer, odometer))
matching = 0
for segment, old_odometer, odometer in sorted(results, key=lambda t: t[0]):
match = old_odometer == odometer
if verbose or not match:
print(f"{segment}: {old_odometer} | {odometer}")
if match:
matching += 1
print("{}/{} matched".format(matching, len(selected)))
@cli @cli
@argh.named("analyze-segment") @argh.named("analyze-segment")
def do_analyze_segment(dbconnect, *segment_paths, base_dir='.', prototypes_path="./odo-digit-prototypes"): def do_analyze_segment(dbconnect, *segment_paths, base_dir='.', prototypes_path="./odo-digit-prototypes"):

Loading…
Cancel
Save