diff --git a/sheetsync/sheetsync/main.py b/sheetsync/sheetsync/main.py index eb06393..cdfa66e 100644 --- a/sheetsync/sheetsync/main.py +++ b/sheetsync/sheetsync/main.py @@ -276,11 +276,10 @@ class EventsSync(SheetSync): "category", } - def __init__(self, name, middleware, stop, dbmanager, reverse_sync=False, media_dir=None, shifts=None): + def __init__(self, name, middleware, stop, dbmanager, reverse_sync=False, media_dir=None): super().__init__(name, middleware, stop, dbmanager, reverse_sync) self.media_dir = media_dir self.media_downloads = None if media_dir is None else {} - self.shifts = shifts def observe_rows(self, rows): @@ -482,12 +481,18 @@ def main(dbconnect, sync_configs, metrics_port=8005, backdoor_port=0, media_dir= client_secret=creds['client_secret'], refresh_token=creds['refresh_token'], ) - if config["type"] in ("events", "archive"): - middleware_cls = { - "events": SheetsEventsMiddleware, - "archive": SheetsArchiveMiddleware, - }[config["type"]] - middleware = middleware_cls( + if config["type"] == "events": + middleware = SheetsEventsMiddleware( + client, + config["sheet_id"], + config["worksheets"], + common.dateutil.parse(config["bustime_start"]), + config["edit_url"], + shifts, + allocate_ids, + ) + elif config["type"] == "archive": + middleware = SheetsArchiveMiddleware( client, config["sheet_id"], config["worksheets"], @@ -528,8 +533,6 @@ def main(dbconnect, sync_configs, metrics_port=8005, backdoor_port=0, media_dir= "archive": ArchiveSync, }[config["type"]] sync_class_kwargs = {} - if config["type"] == "events": - sync_class_kwargs["shifts"] = shifts if config["type"] == "events" and config.get("download_media", False): sync_class_kwargs["media_dir"] = media_dir sync = sync_class(config["name"], middleware, stop, dbmanager, reverse_sync, **sync_class_kwargs) diff --git a/sheetsync/sheetsync/sheets.py b/sheetsync/sheetsync/sheets.py index c44b8ce..46159e2 100644 --- a/sheetsync/sheetsync/sheets.py +++ b/sheetsync/sheetsync/sheets.py @@ -269,10 +269,11 @@ class SheetsEventsMiddleware(SheetsMiddleware): 'id': 15, } - def __init__(self, client, sheet_id, worksheets, bustime_start, edit_url, allocate_ids=False): + def __init__(self, client, sheet_id, worksheets, bustime_start, edit_url, shifts, allocate_ids=False): super().__init__(client, sheet_id, worksheets, allocate_ids) self.bustime_start = bustime_start self.edit_url = edit_url + self.shifts = shifts self.latest_shifts = common.shifts.parse_shifts(self.shifts)