From be77f4ea146207fba66e1794e4cb2f5ec1d2e2ab Mon Sep 17 00:00:00 2001 From: Mike Lang Date: Tue, 14 Nov 2023 16:38:07 +1100 Subject: [PATCH] bus_analyzer: Also record the clock --- bus_analyzer/bus_analyzer/extract.py | 88 ++++++++++++++++++++-------- bus_analyzer/bus_analyzer/main.py | 18 +++--- postgres/setup.sh | 2 + 3 files changed, 77 insertions(+), 31 deletions(-) diff --git a/bus_analyzer/bus_analyzer/extract.py b/bus_analyzer/bus_analyzer/extract.py index d719b1c..976ca57 100644 --- a/bus_analyzer/bus_analyzer/extract.py +++ b/bus_analyzer/bus_analyzer/extract.py @@ -3,7 +3,7 @@ import os from io import BytesIO import argh -from PIL import Image, ImageStat +from PIL import Image import common from common.segments import extract_frame, parse_segment_path @@ -11,9 +11,21 @@ from common.segments import extract_frame, parse_segment_path # DB2023 buscam # bounding box (left x, top y, right x, bottom y) of the area the odometer can be in -ODO_COORDS = 1121, 820, 1270, 897 +AREA_COORDS = { + "odo": (1121, 820, 1270, 897), + "clock": (1685, 819, 1804, 877), +} # starting x coord of each digit within the odo box -DIGIT_X_COORDS = [0, 28, 56, 84, 123] +DIGIT_X_COORDS = { + "odo": [0, 28, 56, 84, 123], + "clock": [0, 27, 66, 93], +} +# value of each digit +DIGIT_BASES = { + "odo": [1000, 100, 10, 1, 0.1], + "clock": [600, 60, 10, 1], +} + DIGIT_WIDTH = 26 # Most digits we only care about the actual character height DIGIT_HEIGHT = 26 @@ -31,7 +43,7 @@ cli = argh.EntryPoint() @cli @argh.arg("paths", nargs="+") -def to_digits(output_dir, paths, box_only=False): +def to_digits(output_dir, paths, box_only=False, type="odo"): """Extracts each digit and saves to a file. Useful for testing or building prototypes.""" if not os.path.exists(output_dir): os.mkdir(output_dir) @@ -39,17 +51,12 @@ def to_digits(output_dir, paths, box_only=False): name = os.path.splitext(os.path.basename(path))[0] image = Image.open(path) if not box_only: - image = extract_odo(image) - for i, digit in enumerate(extract_digits(image)): + image = image.crop(AREA_COORDS[type]) + for i, digit in enumerate(extract_digits(image, type)): output_path = os.path.join(output_dir, "{}-digit{}.png".format(name, i)) digit.save(output_path) -def extract_odo(image): - """Takes a full frame, and returns the odo box""" - return image.crop(ODO_COORDS) - - def get_brightest_region(image, xs, height): """For given image, return the sub-image of given height with the brightest values for all the pixels at given x positions within the row.""" @@ -64,26 +71,30 @@ def get_brightest_region(image, xs, height): return image.crop((0, start_at, image.width, start_at + height)) -def extract_digits(image, include_last=True): +def extract_digits(image, type): """Takes an odo box, and returns a list of digit images""" + main_digit_coords = DIGIT_X_COORDS[type] + if type == "odo": + main_digit_coords = main_digit_coords[:-1] + # convert to greyscale image = image.convert(mode='L') # Find main digits y position digit_xs = [ x + dx - for x in DIGIT_X_COORDS + for x in DIGIT_X_COORDS[type] for dx in range(DIGIT_WIDTH) ] main_digits = get_brightest_region(image, digit_xs, DIGIT_HEIGHT) digits = [] - for i, x in enumerate(DIGIT_X_COORDS[:-1]): + for i, x in enumerate(main_digit_coords): digit = main_digits.crop((x, 0, x + DIGIT_WIDTH, main_digits.height)) digits.append(normalize_digit(digit)) - if include_last: - x = DIGIT_X_COORDS[-1] + if type == "odo": + x = DIGIT_X_COORDS["odo"][-1] last_digit = get_brightest_region(image, range(x, x + DIGIT_WIDTH), LAST_DIGIT_HEIGHT) last_digit = last_digit.crop((x, 0, x + DIGIT_WIDTH, last_digit.height)) digits.append(normalize_digit(last_digit)) @@ -104,13 +115,15 @@ def normalize_digit(digit): return digit -def recognize_digit(prototypes, image): +def recognize_digit(prototypes, image, blank_is_zero=False): """Takes a normalized digit image and returns (detected number, score, all_scores) where score is between 0 and 1. Higher numbers are more certain the number is correct. all_scores is for debugging. If the most likely detection is NOT a number, None is returned instead. """ def maybeFloat(n): + if n == "blank" and blank_is_zero: + return 0 try: return float(n) except ValueError: @@ -166,8 +179,8 @@ def recognize_odometer(prototypes, frame): where score is between 0 and 1. Higher numbers are more certain the value is correct. digits is for debugging. """ - odo = extract_odo(frame) - digits = extract_digits(odo, include_last=True) + odo = frame.crop(AREA_COORDS["odo"]) + digits = extract_digits(odo, "odo") digits = [ recognize_digit(prototypes["odo-digits"], digit) for digit in digits[:-1] ] + [ @@ -177,7 +190,23 @@ def recognize_odometer(prototypes, frame): if any(digit is None for digit, _, _ in digits): value = None else: - value = sum(digit * 10.**i for i, (digit, _, _) in zip(range(3, -2, -1), digits)) + value = sum(digit * base for base, (digit, _, _) in zip(DIGIT_BASES["odo"], digits)) + # Use average score of digits as frame score + score = sum(score for _, score, _ in digits) / len(digits) + return value, score, digits + + +def recognize_clock(prototypes, frame): + clock = frame.crop(AREA_COORDS["clock"]) + digits = extract_digits(clock, "clock") + digits = [ + recognize_digit(prototypes["odo-digits"], digit, i == 0) for i, digit in enumerate(digits) + ] + # If any digit is None, report whole thing as None. Otherwise, calculate the number. + if any(digit is None for digit, _, _ in digits): + value = None + else: + value = sum(digit * base for base, (digit, _, _) in zip(DIGIT_BASES["clock"], digits)) # Use average score of digits as frame score score = sum(score for _, score, _ in digits) / len(digits) return value, score, digits @@ -185,7 +214,7 @@ def recognize_odometer(prototypes, frame): @cli @argh.arg("frames", nargs="+") -def read_frame(frames, prototypes_path="./prototypes", verbose=False, include_last=False): +def read_frame(frames, prototypes_path="./prototypes", verbose=False): """For testing. Takes any number of frame images (or segments) and prints the odometer reading.""" prototypes = load_prototypes(prototypes_path) for filename in frames: @@ -195,11 +224,18 @@ def read_frame(frames, prototypes_path="./prototypes", verbose=False, include_la frame = Image.open(BytesIO(frame_data)) else: frame = Image.open(filename) + value, score, digits = recognize_odometer(prototypes, frame) if verbose: for guess, score, all_scores in digits: print("Digit = {} with score {}".format(guess, score)) - print("{}: {} with score {}".format(filename, value, score)) + print("{}: odo {} with score {}".format(filename, value, score)) + + value, score, digits = recognize_clock(prototypes, frame) + if verbose: + for guess, score, all_scores in digits: + print("Digit = {} with score {}".format(guess, score)) + print("{}: clock {} with score {}".format(filename, value, score)) @cli @@ -229,10 +265,16 @@ def get_frame(*segments): def extract_segment(prototypes, segment): ODO_SCORE_THRESHOLD = 0.01 + CLOCK_SCORE_THRESHOLD = 0.01 frame_data = b"".join(extract_frame([segment], segment.start)) frame = Image.open(BytesIO(frame_data)) odometer, score, _ = recognize_odometer(prototypes, frame) - return odometer if score >= ODO_SCORE_THRESHOLD else None + if score < ODO_SCORE_THRESHOLD: + odometer = None + clock, score, _ = recognize_clock(prototypes, frame) + if score < CLOCK_SCORE_THRESHOLD: + clock = None + return odometer, clock if __name__ == '__main__': diff --git a/bus_analyzer/bus_analyzer/main.py b/bus_analyzer/bus_analyzer/main.py index 819398c..d5536df 100644 --- a/bus_analyzer/bus_analyzer/main.py +++ b/bus_analyzer/bus_analyzer/main.py @@ -25,8 +25,8 @@ def do_extract_segment(*segment_paths, prototypes_path="./prototypes"): prototypes = load_prototypes(prototypes_path) for segment_path in segment_paths: segment_info = parse_segment_path(segment_path) - odometer = extract_segment(prototypes, segment_info) - print(f"{segment_path} {odometer}") + odometer, clock = extract_segment(prototypes, segment_info) + print(f"{segment_path} {odometer} {clock}") @cli @@ -75,7 +75,7 @@ def compare_segments(dbconnect, base_dir='.', prototypes_path="./prototypes", si for old_odometer, segment in selected: path = os.path.join(base_dir, segment) segment_info = parse_segment_path(path) - odometer = extract_segment(prototypes, segment_info) + odometer, clock = extract_segment(prototypes, segment_info) results.append((segment, old_odometer, odometer)) matching = 0 @@ -112,29 +112,31 @@ def analyze_segment(conn, prototypes, segment_path, check_segment_name=None): assert segment_name == check_segment_name try: - odometer = extract_segment(prototypes, segment_info) + odometer, clock = extract_segment(prototypes, segment_info) except Exception: logging.warning(f"Failed to extract segment {segment_path!r}", exc_info=True) odometer = None error = traceback.format_exc() else: - logging.info(f"Got odometer = {odometer} for segment {segment_path!r}") + logging.info(f"Got odometer = {odometer}, clock = {clock} for segment {segment_path!r}") error = None database.query( conn, """ - INSERT INTO bus_data (channel, timestamp, segment, error, odometer) - VALUES (%(channel)s, %(timestamp)s, %(segment)s, %(error)s, %(odometer)s) + INSERT INTO bus_data (channel, timestamp, segment, error, odometer, clock) + VALUES (%(channel)s, %(timestamp)s, %(segment)s, %(error)s, %(odometer)s, %(clock)s) ON CONFLICT (channel, timestamp, segment) DO UPDATE SET error = %(error)s, - odometer = %(odometer)s + odometer = %(odometer)s, + clock = %(clock)s """, channel=segment_info.channel, timestamp=segment_info.start, segment=segment_name, error=error, odometer=odometer, + clock=clock, ) diff --git a/postgres/setup.sh b/postgres/setup.sh index ab31c08..fca856d 100644 --- a/postgres/setup.sh +++ b/postgres/setup.sh @@ -156,6 +156,7 @@ CREATE TABLE playlists ( -- The "error" column records a free-form human readable message about why a value could not -- be determined. -- The odometer column is in miles. The game shows the odometer to the 1/10th mile precision. +-- The clock is in minutes since 00:00, in 12h time. -- The segment may be NULL, which indicates a manually-inserted value. -- The primary key serves two purposes: -- It provides an index on channel, followed by a range index on timestamp @@ -168,6 +169,7 @@ CREATE TABLE bus_data ( segment TEXT, error TEXT, odometer DOUBLE PRECISION, + clock INTEGER, PRIMARY KEY (channel, timestamp, segment) ); EOSQL