diff --git a/bus_analyzer/bus_analyzer/extract.py b/bus_analyzer/bus_analyzer/extract.py index 9da601a..cb5a20f 100644 --- a/bus_analyzer/bus_analyzer/extract.py +++ b/bus_analyzer/bus_analyzer/extract.py @@ -129,7 +129,7 @@ def normalize(image): return image -def recognize_digit(prototypes, image, blank_is_zero=False): +def recognize_digit(prototypes, image, blank_is_zero=False, mask=None): """Takes a normalized digit image and returns (detected number, score, all_scores) where score is between 0 and 1. Higher numbers are more certain the number is correct. all_scores is for debugging. @@ -143,7 +143,7 @@ def recognize_digit(prototypes, image, blank_is_zero=False): except ValueError: return None scores = sorted([ - (compare_images(prototype, image), maybeFloat(n)) + (compare_images(prototype, image, mask), maybeFloat(n)) for n, prototype in prototypes.items() ], reverse=True) best_score, number = scores[0] @@ -153,10 +153,16 @@ def recognize_digit(prototypes, image, blank_is_zero=False): return number, best_score - runner_up_score, scores -def compare_images(prototype, image): +def compare_images(prototype, image, mask=None): """Takes a normalized digit image and a prototype image, and returns a score for how close the image is to looking like that prototype.""" - pairs = zip(image.getdata(), prototype.getdata()) + image = list(image.getdata()) + prototype = list(prototype.getdata()) + if mask: + mask = list(mask.getdata()) + image = [mask[i] / 255 * image[i] for i in range(len(mask))] + prototype = [mask[i] / 255 * prototype[i] for i in range(len(mask))] + pairs = zip(image, prototype) error_squared = sum((a - b)**2 for a, b in pairs) MAX_ERROR_SQUARED = 255**2 * len(pairs) return 1 - (float(error_squared) / MAX_ERROR_SQUARED)**0.5 @@ -195,8 +201,9 @@ def recognize_odometer(prototypes, frame): """ odo = frame.crop(AREA_COORDS["odo"]) digits = extract_digits(odo, "odo") + mask = Image.open(prototypes['mask'][0]) digits = [ - recognize_digit(prototypes["odo-digits"], digit) for digit in digits[:-1] + recognize_digit(prototypes["odo-digits"], digit, mask=mask) for digit in digits[:-1] ] + [ recognize_digit(prototypes["odo-last-digit"], digits[-1]) ] @@ -213,8 +220,9 @@ def recognize_odometer(prototypes, frame): def recognize_clock(prototypes, frame): clock = frame.crop(AREA_COORDS["clock"]) digits = extract_digits(clock, "clock") + mask = Image.open(prototypes['mask'][0]) digits = [ - recognize_digit(prototypes["odo-digits"], digit, i == 0) for i, digit in enumerate(digits) + recognize_digit(prototypes["odo-digits"], digit, i == 0, mask=mask) for i, digit in enumerate(digits) ] if any(digit is None for digit, _, _ in digits): # If any digit is None, report whole thing as None diff --git a/bus_analyzer/prototypes/mask/mask.png b/bus_analyzer/prototypes/mask/mask.png new file mode 100644 index 0000000..6bde809 Binary files /dev/null and b/bus_analyzer/prototypes/mask/mask.png differ