diff --git a/bus_analyzer/bus_analyzer/extract.py b/bus_analyzer/bus_analyzer/extract.py index b5a89e5..08ef4e2 100644 --- a/bus_analyzer/bus_analyzer/extract.py +++ b/bus_analyzer/bus_analyzer/extract.py @@ -1,3 +1,201 @@ -def extract_segment(path): - raise NotImplementedError +import os +from io import BytesIO + +import argh +from PIL import Image, ImageStat + +from common.segments import extract_frame, parse_segment_path + + +# DB2023 buscam +# bounding box (left x, top y, right x, bottom y) of the area the odometer can be in +ODO_COORDS = 1121, 820, 1270, 897 +# starting x coord of each digit within the odo box +DIGIT_X_COORDS = [0, 28, 56, 84, 123] +DIGIT_WIDTH = 26 +# Most digits we only care about the actual character height +DIGIT_HEIGHT = 26 +# But last digit we want the full white-background area as we want to try to match +# based on position also. +LAST_DIGIT_HEIGHT = 38 + +# get back py2 zip behaviour +_zip = zip +def zip(*args): + return list(_zip(*args)) + + +cli = argh.EntryPoint() + +@cli +@argh.arg("paths", nargs="+") +def to_digits(output_dir, paths, box_only=False): + """Extracts each digit and saves to a file. Useful for testing or building prototypes.""" + if not os.path.exists(output_dir): + os.mkdir(output_dir) + for path in paths: + name = os.path.splitext(os.path.basename(path))[0] + image = Image.open(path) + if not box_only: + image = extract_odo(image) + for i, digit in enumerate(extract_digits(image)): + output_path = os.path.join(output_dir, "{}-digit{}.png".format(name, i)) + digit.save(output_path) + + +def extract_odo(image): + """Takes a full frame, and returns the odo box""" + return image.crop(ODO_COORDS) + + +def extract_digits(image, include_last=True): + """Takes an odo box, and returns a list of 9x6 digit images""" + # convert to greyscale + image = image.convert(mode='L') + digits = [] + for i, x in enumerate(DIGIT_X_COORDS): + # last digit is special + is_last = i == len(DIGIT_X_COORDS) - 1 + if is_last and not include_last: + continue + digit = image.crop((x, 0, x + DIGIT_WIDTH, image.height)) + digits.append(normalize_digit(digit, is_last)) + return digits + + +def normalize_digit(digit, is_last=False): + # Calculate total brightness by row + rows = [ + sum(digit.getpixel((x, y)) for x in range(digit.width)) + for y in range(digit.height) + ] + # Find brightest sub-image of DIGIT_HEIGHT rows + h = LAST_DIGIT_HEIGHT if is_last else DIGIT_HEIGHT + start_at = max(range(digit.height - (h-1)), key=lambda y: sum(rows[y:y+h])) + # Cut image to only be that part + digit = digit.crop((0, start_at, digit.width, start_at + h)) + + # Last digit is inverted - by looking for brightest sub-image we've likely found + # the section that has a white background. Now we want to normalize that so it looks like + # other images. + if is_last: + digit = digit.point(lambda v: 255 - v) + + # Expand the range of the image so that the darkest pixel becomes black + # and the lightest becomes white + _min, _max = digit.getextrema() + _range = _max - _min + if _range == 0: + digit = digit.point(lambda v: 128) + else: + digit = digit.point(lambda v: 255 * (v - _min) / _range) + + return digit + + +def recognize_digit(prototypes, image): + """Takes a normalized digit image and returns (detected number, score, all_scores) + where score is between 0 and 1. Higher numbers are more certain the number is correct. + all_scores is for debugging. + """ + scores = sorted([ + (compare_images(prototypes[n], image), n) + for n in range(10) + ], reverse=True) + best_score, number = scores[0] + runner_up_score, _ = scores[1] + # we penalize score if the second best score is high, as this indicates we're uncertain + # which number it is even though both match. + return number, best_score - runner_up_score, scores + + +def compare_images(prototype, image): + """Takes a normalized digit image and a prototype image, and returns a score + for how close the image is to looking like that prototype.""" + pairs = zip(image.getdata(), prototype.getdata()) + error_squared = sum((a - b)**2 for a, b in pairs) + MAX_ERROR_SQUARED = 255**2 * len(pairs) + return 1 - (float(error_squared) / MAX_ERROR_SQUARED)**0.5 + + +def load_prototypes(prototypes_path): + return [ + Image.open(os.path.join(prototypes_path, "{}.png".format(n))) + for n in range(10) + ] + + +@cli +def read_digit(digit, prototypes_path="./odo-digit-prototypes", verbose=False): + """For debugging. Compares an extracted digit image to each prototype and prints scores.""" + prototypes = load_prototypes(prototypes_path) + digit = Image.open(digit) + guess, score, all_scores = recognize_digit(prototypes, digit) + print("Digit = {} with score {}".format(guess, score)) + if verbose: + all_scores.sort(key=lambda x: x[1]) + for s, n in all_scores: + print("{}: {}".format(n, s)) + + +def recognize_odometer(prototypes, frame): + """Takes a full image frame and returns (detected mile value, score, digits) + where score is between 0 and 1. Higher numbers are more certain the value is correct. + digits is for debugging. + """ + odo = extract_odo(frame) + digits = extract_digits(odo, include_last=False) + digits = [recognize_digit(prototypes, digit) for digit in digits] + value = sum(digit * 10**i for i, (digit, _, _) in enumerate(digits[::-1])) + # Use average score of digits as frame score + score = sum(score for _, score, _ in digits) / len(digits) + return value, score, digits + + +@cli +@argh.arg("frames", nargs="+") +def read_frame(frames, prototypes_path="./odo-digit-prototypes", verbose=False, include_last=False): + """For testing. Takes any number of frame images (or segments) and prints the odometer reading.""" + prototypes = load_prototypes(prototypes_path) + for filename in frames: + if filename.endswith(".ts"): + segment = parse_segment_path(filename) + frame_data = b"".join(extract_frame([segment], segment.start)) + frame = Image.open(BytesIO(frame_data)) + else: + frame = Image.open(filename) + value, score, digits = recognize_odometer(prototypes, frame) + if verbose: + for guess, score, all_scores in digits: + print("Digit = {} with score {}".format(guess, score)) + print("{}: {} with score {}".format(filename, value, score)) + + +@cli +def create_prototype(output, *images): + """Create a prototype image by averaging all the given images""" + first = Image.open(images[0]) + data = list(first.getdata()) + for image in images[1:]: + image = Image.open(image) + for i, value in enumerate(image.getdata()): + data[i] += value + data = [v / len(images) for v in data] + first.putdata(data) + first.save(output) + + +def extract_segment(prototypes, segment): + # We haven't observed worse than 0.15 or so in the wild, + # and an all-black screen is identified as "1" with a score of 0.07. + # So as a rough middle ground, require at least 0.1. + ODO_SCORE_THRESHOLD = 0.1 + frame_data = b"".join(extract_frame([segment], segment.start)) + frame = Image.open(BytesIO(frame_data)) + odometer, score, _ = recognize_odometer(prototypes, frame) + return odometer if score >= ODO_SCORE_THRESHOLD else None + + +if __name__ == '__main__': + cli() diff --git a/bus_analyzer/bus_analyzer/main.py b/bus_analyzer/bus_analyzer/main.py index 0f2a509..9d0770c 100644 --- a/bus_analyzer/bus_analyzer/main.py +++ b/bus_analyzer/bus_analyzer/main.py @@ -11,7 +11,7 @@ import gevent.event from common import database from common.segments import parse_segment_path -from .extract import extract_segment +from .extract import extract_segment, load_prototypes cli = argh.EntryPoint() @@ -19,32 +19,39 @@ cli = argh.EntryPoint() @cli @argh.named("extract-segment") -def do_extract_segment(*segment_paths): +def do_extract_segment(*segment_paths, prototypes_path="./odo-digit-prototypes"): """Extract info from individual segments and print them""" + prototypes = load_prototypes(prototypes_path) for segment_path in segment_paths: - odometer = extract_segment(segment_path) + segment_info = parse_segment_path(segment_path) + odometer = extract_segment(prototypes, segment_info) print(f"{segment_path} {odometer}") @cli @argh.named("analyze-segment") -def do_analyze_segment(dbconnect, *segment_paths, base_dir='.'): +def do_analyze_segment(dbconnect, *segment_paths, base_dir='.', prototypes_path="./odo-digit-prototypes"): """Analyze individual segments and write them to the database""" + prototypes = load_prototypes(prototypes_path) dbmanager = database.DBManager(dsn=dbconnect) conn = dbmanager.get_conn() for segment_path in segment_paths: - analyze_segment(conn, segment_path) + analyze_segment(conn, prototypes, segment_path) -def analyze_segment(conn, segment_path, check_segment_name=None): +def analyze_segment(conn, prototypes, segment_path, check_segment_name=None): segment_info = parse_segment_path(segment_path) + if segment_info.type == "temp": + logging.info("Ignoring temp segment {}".format(segment_path)) + return + segment_name = '/'.join(segment_path.split('/')[-4:]) # just keep last 4 path parts if check_segment_name is not None: assert segment_name == check_segment_name try: - odometer = extract_segment(segment_path) + odometer = extract_segment(prototypes, segment_info) except Exception: logging.warning(f"Failed to extract segment {segment_path!r}", exc_info=True) odometer = None @@ -70,7 +77,7 @@ def analyze_segment(conn, segment_path, check_segment_name=None): ) -def analyze_hour(conn, existing_segments, base_dir, channel, quality, hour): +def analyze_hour(conn, prototypes, existing_segments, base_dir, channel, quality, hour): hour_path = os.path.join(base_dir, channel, quality, hour) try: segments = os.listdir(hour_path) @@ -93,7 +100,7 @@ def analyze_hour(conn, existing_segments, base_dir, channel, quality, hour): logging.info("Found {} segments not already existing".format(len(segments_to_do))) for segment_path, segment_name in segments_to_do: - analyze_segment(conn, segment_path, segment_name) + analyze_segment(conn, prototypes, segment_path, segment_name) def parse_hours(s): @@ -105,6 +112,7 @@ def parse_hours(s): @cli @argh.arg("--hours", type=parse_hours, help="If integer, watch the most recent N hours. Otherwise, comma-seperated list of hours.") +@argh.arg("channels", nargs="+") def main( dbconnect, *channels, @@ -113,6 +121,7 @@ def main( hours=2, run_once=False, overwrite=False, + prototypes_path="./odo-digit-prototypes", ): CHECK_INTERVAL = 2 @@ -123,6 +132,8 @@ def main( db_manager = database.DBManager(dsn=dbconnect) conn = db_manager.get_conn() + prototypes = load_prototypes(prototypes_path) + logging.info("Started") while not stopping.is_set(): @@ -162,7 +173,7 @@ def main( for channel in channels: for hour in do_hours: - analyze_hour(conn, existing_segments, base_dir, channel, quality, hour) + analyze_hour(conn, prototypes, existing_segments, base_dir, channel, quality, hour) if run_once: logging.info("Requested to only run once, stopping")