Allow for masking digits

chrusher/bus_synthesizer
Christopher Usher 2 weeks ago
parent 6c97654da3
commit 656be6c292

@ -129,7 +129,7 @@ def normalize(image):
return image
def recognize_digit(prototypes, image, blank_is_zero=False):
def recognize_digit(prototypes, image, blank_is_zero=False, mask=None):
"""Takes a normalized digit image and returns (detected number, score, all_scores)
where score is between 0 and 1. Higher numbers are more certain the number is correct.
all_scores is for debugging.
@ -143,7 +143,7 @@ def recognize_digit(prototypes, image, blank_is_zero=False):
except ValueError:
return None
scores = sorted([
(compare_images(prototype, image), maybeFloat(n))
(compare_images(prototype, image, mask), maybeFloat(n))
for n, prototype in prototypes.items()
], reverse=True)
best_score, number = scores[0]
@ -153,10 +153,16 @@ def recognize_digit(prototypes, image, blank_is_zero=False):
return number, best_score - runner_up_score, scores
def compare_images(prototype, image):
def compare_images(prototype, image, mask=None):
"""Takes a normalized digit image and a prototype image, and returns a score
for how close the image is to looking like that prototype."""
pairs = zip(image.getdata(), prototype.getdata())
image = list(image.getdata())
prototype = list(prototype.getdata())
if mask:
mask = list(mask.getdata())
image = [mask[i] / 255 * image[i] for i in range(len(mask))]
prototype = [mask[i] / 255 * prototype[i] for i in range(len(mask))]
pairs = zip(image, prototype)
error_squared = sum((a - b)**2 for a, b in pairs)
MAX_ERROR_SQUARED = 255**2 * len(pairs)
return 1 - (float(error_squared) / MAX_ERROR_SQUARED)**0.5
@ -195,8 +201,9 @@ def recognize_odometer(prototypes, frame):
"""
odo = frame.crop(AREA_COORDS["odo"])
digits = extract_digits(odo, "odo")
mask = Image.open(prototypes['mask'][0])
digits = [
recognize_digit(prototypes["odo-digits"], digit) for digit in digits[:-1]
recognize_digit(prototypes["odo-digits"], digit, mask=mask) for digit in digits[:-1]
] + [
recognize_digit(prototypes["odo-last-digit"], digits[-1])
]
@ -213,8 +220,9 @@ def recognize_odometer(prototypes, frame):
def recognize_clock(prototypes, frame):
clock = frame.crop(AREA_COORDS["clock"])
digits = extract_digits(clock, "clock")
mask = Image.open(prototypes['mask'][0])
digits = [
recognize_digit(prototypes["odo-digits"], digit, i == 0) for i, digit in enumerate(digits)
recognize_digit(prototypes["odo-digits"], digit, i == 0, mask=mask) for i, digit in enumerate(digits)
]
if any(digit is None for digit, _, _ in digits):
# If any digit is None, report whole thing as None

Binary file not shown.

After

Width:  |  Height:  |  Size: 94 B

Loading…
Cancel
Save