import logging
import os
from celery import Celery
from typing import Union, Callable
from whisper import tokenizer
import tqdm
from .util.audio import load_audio

logging.basicConfig(format='[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s', level=logging.INFO, force=True)
logger = logging.getLogger(__name__)


# monkeypatch tqdm to fool whisper's `transcribe` function
class _TQDM(tqdm.tqdm):
    """tqdm subclass that reports progress to a registered callback.

    This class is installed in place of ``tqdm.tqdm`` at module level. Code
    that looks up ``tqdm.tqdm`` when it runs therefore creates instances of
    this class; that includes whisper's ``transcribe``, per the comment
    above. If a callback was registered via ``set_progress_function`` when
    the bar was created, every ``update`` is forwarded to it. Otherwise the
    bar behaves like stock tqdm.
    """

    # The original tqdm class, captured before the module-level patch replaces it.
    _tqdm = tqdm.tqdm
    # Class-wide callback, called as progress_function(unit, total, current).
    progress_function: Union[Callable[[str, int, int], None], None] = None

    def __init__(self, *argv, total=0, unit="", **kwargs):
        """Record ``total`` and ``unit`` for the callback, then initialise tqdm.

        ``total`` and ``unit`` are forwarded to tqdm only when given a
        non-empty value. Otherwise tqdm keeps its own defaults.
        """
        logger.debug(f"Creating TQDM with total={total}, unit={unit}")
        self._total = total
        self._unit = unit
        self._progress = 0
        # Snapshot the callback at creation time; later registrations don't affect this bar.
        self.progress_function = _TQDM.progress_function or None
        # Previously total/unit were swallowed here, so the fallback bar lost them.
        if total:
            kwargs["total"] = total
        if unit:
            kwargs["unit"] = unit
        super().__init__(*argv, **kwargs)

    @staticmethod
    def set_progress_function(progress_function: Union[Callable[[str, int, int], None], None]):
        """Register (or clear, with ``None``) the callback used by bars created afterwards."""
        logger.debug(f"Setting progress function to {progress_function}")
        _TQDM.progress_function = progress_function

    def update(self, progress=1):
        """Advance the bar by ``progress`` units (default 1, like ``tqdm.update``).

        When a callback is attached, it receives ``(unit, total, current)``
        and ``None`` is returned. Otherwise this defers to tqdm's own
        ``update`` and returns its result.
        """
        logger.debug(f"Updating TQDM with progress={progress}")
        self._progress += progress
        if self.progress_function is not None:
            self.progress_function(self._unit, self._total, self._progress)
            return None
        return super().update(progress)


tqdm.tqdm = _TQDM

ASR_ENGINE = os.getenv("ASR_ENGINE", "faster_whisper")
if ASR_ENGINE == "faster_whisper":
    from .faster_whisper.core import load_model, transcribe as whisper_transcribe
else:
    from .openai_whisper.core import load_model, transcribe as whisper_transcribe

# Sorted language codes known to whisper's tokenizer (the keys of tokenizer.LANGUAGES).
LANGUAGE_CODES = sorted(tokenizer.LANGUAGES)
DEFAULT_MODEL_NAME = os.getenv("ASR_MODEL", "small")

# Custom Celery task states published via update_state while a job runs.
STATES = {
    'loading_model': 'LOADING_MODEL',
    'encoding': 'ENCODING',
    'transcribing': 'TRANSCRIBING',
}

celery = Celery(__name__)
celery.conf.broker_connection_retry_on_startup = True
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379")
celery.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379")
celery.conf.worker_hijack_root_logger = False
celery.conf.worker_redirect_stdouts_level = "DEBUG"
@celery.task(name="transcribe", bind=True)
def transcribe(
    self,
    audio_file_path: str,
    original_filename: str,
    asr_options: dict,
):
    """Celery task: transcribe an uploaded audio file and write the result to disk.

    Args:
        self: The bound Celery task (``bind=True``); used for ``update_state``
            and ``request.id``.
        audio_file_path: Path of the uploaded audio. The file is deleted when
            the task finishes, whether or not transcription succeeded.
        original_filename: Name of the file as uploaded; used to build the
            output file name.
        asr_options: Transcription options. ``"output"`` (required) is the
            output format, which is also used as the file extension.
            ``"model_name"`` and ``"encode"`` are optional.

    Returns:
        dict with ``output_filename``, ``output_path`` (the written file on
        disk) and ``url_path`` (built from ``OUTPUT_URL_PREFIX``, the job id
        and the file name).
    """
    logger.info(f"Transcribing {audio_file_path} with {asr_options}")
    output_format = asr_options["output"]
    try:
        with open(audio_file_path, "rb") as audio_file:
            # Route tqdm progress reported while we work into this task's state.
            _TQDM.set_progress_function(update_progress(self))
            try:
                model_name = asr_options.get("model_name") or DEFAULT_MODEL_NAME
                logger.info(f"Loading model {model_name}")
                self.update_state(state=STATES["loading_model"], meta={"progress": {"units": "models", "total": 1, "current": 0}})
                load_model(model_name)

                logger.info(f"Loading audio from {audio_file_path}")
                self.update_state(state=STATES["encoding"], meta={"progress": {"units": "files", "total": 1, "current": 0}})
                audio_data = load_audio(audio_file, asr_options.get("encode", False))

                logger.info("Transcribing audio")
                self.update_state(state=STATES["transcribing"], meta={"progress": {"units": "files", "total": 1, "current": 0}})
                result = whisper_transcribe(audio_data, asr_options, output_format)
            finally:
                # The callback is stored on the _TQDM class; clear it so later bars don't report into this task.
                _TQDM.set_progress_function(None)
    finally:
        # Delete the upload even when transcription fails so failed jobs don't leave audio files behind.
        _discard_file(audio_file_path)
    logger.info("Transcription complete")

    job_id = self.request.id
    filename = _output_filename(original_filename, output_format)
    output_directory = get_output_path(job_id)
    output_path = f"{output_directory}/{filename}"
    logger.info(f"Writing result to {output_path}")
    # exist_ok avoids the check-then-create race between concurrent workers.
    os.makedirs(output_directory, exist_ok=True)
    # Transcripts routinely contain non-ASCII text; don't depend on the locale's default encoding.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(result.read())

    url_path = f"{get_output_url_path(job_id)}/{filename}"
    return {
        "output_filename": filename,
        "output_path": output_path,
        "url_path": url_path,
    }


def _output_filename(original_filename: str, output_format: str) -> str:
    """Build a single-component output file name: ``<sanitized name>.<output_format>``."""
    # Drop characters latin-1 cannot represent, then decode with the SAME codec.
    # Decoding latin-1 bytes as UTF-8 raised UnicodeDecodeError for names such as "café.mp3".
    latin1_name = original_filename.encode("latin-1", "ignore").decode("latin-1")
    # Keep only the last path component so names like "../../x" cannot escape the job directory.
    return f"{os.path.basename(latin1_name)}.{output_format}"


def _discard_file(path: str) -> None:
    """Remove ``path``, tolerating it already being gone."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def get_output_path(job_id: str):
    """Return the directory for a job's output: ``$OUTPUT_DIRECTORY/<job_id>``.

    Defaults to ``<cwd>/app/output`` when ``OUTPUT_DIRECTORY`` is unset.
    """
    return os.environ.get("OUTPUT_DIRECTORY", os.getcwd() + "/app/output") + "/" + job_id


def get_output_url_path(job_id: str):
    """Return the URL path prefix for a job's output: ``$OUTPUT_URL_PREFIX/<job_id>``.

    Defaults to ``/output`` when ``OUTPUT_URL_PREFIX`` is unset.
    """
    return os.environ.get("OUTPUT_URL_PREFIX", "/output") + "/" + job_id


def update_progress(context):
    """Return a callback that publishes progress as task state.

    Args:
        context: Object with an ``update_state`` method. In this file it is
            the bound Celery task.

    Returns:
        ``do_update(units, total, current)``. This matches the arguments that
        ``_TQDM.update`` passes to its progress function. Every update is
        reported under the ``TRANSCRIBING`` state.
    """
    def do_update(units, total, current):
        logger.info(f"Updating progress with units={units}, total={total}, current={current}")
        context.update_state(
            state=STATES["transcribing"],
            meta={"progress": {"units": units, "total": total, "current": current}}
        )

    return do_update