From 11164271b288da40aa77bf9913bdc61cc86800eb Mon Sep 17 00:00:00 2001 From: HeNine <> Date: Mon, 20 Sep 2021 20:36:21 +0200 Subject: [PATCH] Gap handling fixed --- buscribe/buscribe/buscribe.py | 8 ++++---- buscribe/buscribe/main.py | 5 ++++- buscribe/buscribe/recognizer.py | 22 ++++++++++++++++------ 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/buscribe/buscribe/buscribe.py b/buscribe/buscribe/buscribe.py index 7c2a764..32adf02 100644 --- a/buscribe/buscribe/buscribe.py +++ b/buscribe/buscribe/buscribe.py @@ -43,8 +43,8 @@ def transcribe_segments(segments: list, sample_rate: int, recognizer: BuscribeRe data = process.stdout.read(16000) if len(data) == 0: break - if recognizer.AcceptWaveform(data): - result_json = json.loads(recognizer.Result()) + if recognizer.accept_waveform(data): + result_json = json.loads(recognizer.result()) logging.debug(json.dumps(result_json, indent=2)) if result_json["text"] == "": @@ -92,11 +92,11 @@ def get_end_of_transcript(db_cursor): def finish_off_recognizer(recognizer: BuscribeRecognizer, db_cursor): """Flush the recognizer, commit the final line to the database and reset it.""" - final_result_json = json.loads(recognizer.FinalResult()) # Flush the tubes + final_result_json = json.loads(recognizer.final_result()) # Flush the tubes line_start_time = recognizer.segments_start_time + timedelta(seconds=final_result_json["result"][0]["start"]) line_end_time = recognizer.segments_start_time + timedelta(seconds=final_result_json["result"][-1]["end"]) write_line(final_result_json, line_start_time, line_end_time, db_cursor) - recognizer.Reset() + recognizer.reset() diff --git a/buscribe/buscribe/main.py b/buscribe/buscribe/main.py index 281f07c..a4ebe2a 100644 --- a/buscribe/buscribe/main.py +++ b/buscribe/buscribe/main.py @@ -63,6 +63,7 @@ def main(database="", base_dir=".", logging.info('Transcribing from {}'.format(start_time)) # Start priming the recognizer if possible + start_of_transcription = start_time start_time -= timedelta(minutes=2) stopping = gevent.event.Event() @@ -85,8 +86,10 @@ def main(database="", base_dir=".", if recognizer.segments_start_time is None: recognizer.segments_start_time = segments[0].start + logging.info(f"Starting from: {segments[0].start}") - segments_end_time = transcribe_segments(segments, SAMPLE_RATE, recognizer, start_time, db_cursor, stopping) + segments_end_time = transcribe_segments(segments, SAMPLE_RATE, recognizer, start_of_transcription, db_cursor, + stopping) if end_time is not None and segments_end_time >= end_time \ or stopping.is_set(): diff --git a/buscribe/buscribe/recognizer.py b/buscribe/buscribe/recognizer.py index 888c77e..075820e 100644 --- a/buscribe/buscribe/recognizer.py +++ b/buscribe/buscribe/recognizer.py @@ -1,7 +1,7 @@ from vosk import Model, SpkModel, KaldiRecognizer -class BuscribeRecognizer(KaldiRecognizer): +class BuscribeRecognizer(): segments_start_time = None def __init__(self, sample_rate=48000, model_path="model_small", spk_model_path="spk_model"): @@ -11,13 +11,23 @@ class BuscribeRecognizer(KaldiRecognizer): Returns a recognizer object. """ + self.sample_rate = sample_rate self.model = Model(model_path) self.spk_model = SpkModel(spk_model_path) - super(BuscribeRecognizer, self).__init__(self.model, sample_rate, self.spk_model) + self.recognizer = KaldiRecognizer(self.model, self.sample_rate, self.spk_model) + self.recognizer.SetWords(True) - self.SetWords(True) - - def Reset(self): - super(BuscribeRecognizer, self).Reset() + def reset(self): + self.recognizer = KaldiRecognizer(self.model, self.sample_rate, self.spk_model) + self.recognizer.SetWords(True) self.segments_start_time = None + + def accept_waveform(self, data): + return self.recognizer.AcceptWaveform(data) + + def result(self): + return self.recognizer.Result() + + def final_result(self): + return self.recognizer.FinalResult()