diff --git a/bus_analyzer/bus_analyzer/main.py b/bus_analyzer/bus_analyzer/main.py index 37e742c..0fd1c86 100644 --- a/bus_analyzer/bus_analyzer/main.py +++ b/bus_analyzer/bus_analyzer/main.py @@ -8,6 +8,7 @@ import traceback import argh import gevent.event +from gevent.pool import Pool from common import database from common.segments import parse_segment_path, list_segment_files @@ -102,13 +103,12 @@ def do_analyze_segment(dbconnect, *segment_paths, base_dir='.', prototypes_path= """Analyze individual segments and write them to the database""" prototypes = load_prototypes(prototypes_path) dbmanager = database.DBManager(dsn=dbconnect) - conn = dbmanager.get_conn() for segment_path in segment_paths: - analyze_segment(conn, prototypes, segment_path) + analyze_segment(db_manager, prototypes, segment_path) -def analyze_segment(conn, prototypes, segment_path, check_segment_name=None): +def analyze_segment(db_manager, prototypes, segment_path, check_segment_name=None): segment_info = parse_segment_path(segment_path) if segment_info.type == "temp": logging.info("Ignoring temp segment {}".format(segment_path)) @@ -123,11 +123,14 @@ def analyze_segment(conn, prototypes, segment_path, check_segment_name=None): except Exception: logging.warning(f"Failed to extract segment {segment_path!r}", exc_info=True) odometer = None + clock = None + tod = None error = traceback.format_exc() else: logging.info(f"Got odometer = {odometer}, clock = {clock}, time of day = {tod} for segment {segment_path!r}") error = None + conn = db_manager.get_conn() database.query( conn, """ @@ -147,9 +150,10 @@ def analyze_segment(conn, prototypes, segment_path, check_segment_name=None): clock=clock, timeofday=tod, ) + db_manager.put_conn(conn) -def analyze_hour(conn, prototypes, existing_segments, base_dir, channel, quality, hour): +def analyze_hour(db_manager, prototypes, existing_segments, base_dir, channel, quality, hour, concurrency=10): hour_path = os.path.join(base_dir, channel, quality, hour) segments = sorted(list_segment_files(hour_path)) @@ -167,8 +171,12 @@ def analyze_hour(conn, prototypes, existing_segments, base_dir, channel, quality segments_to_do.append((segment_path, segment_name)) logging.info("Found {} segments not already existing".format(len(segments_to_do))) + pool = Pool(concurrency) + workers = [] for segment_path, segment_name in segments_to_do: - analyze_segment(conn, prototypes, segment_path, segment_name) + workers.append(pool.spawn(analyze_segment, db_manager, prototypes, segment_path, segment_name)) + for worker in workers: + worker.get() # re-raise errors def parse_hours(s): @@ -190,6 +198,7 @@ def main( run_once=False, overwrite=False, prototypes_path="./prototypes", + concurrency=10, ): CHECK_INTERVAL = 2 @@ -241,7 +250,7 @@ def main( for channel in channels: for hour in do_hours: - analyze_hour(conn, prototypes, existing_segments, base_dir, channel, quality, hour) + analyze_hour(db_manager, prototypes, existing_segments, base_dir, channel, quality, hour, concurrency=concurrency) if run_once: logging.info("Requested to only run once, stopping")