# ReaSpeech-Cloud / app / worker.py
# j
# initial commit
# 402daee
import logging
import os
from celery import Celery
from typing import Union, Callable
from whisper import tokenizer
import tqdm
from .util.audio import load_audio
# Configure the root logger for this process. `force=True` makes this call
# take effect even if handlers were already attached to the root logger.
logging.basicConfig(
    level=logging.INFO,
    format='[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s',
    force=True,
)

# Module-level logger used by everything in this file.
logger = logging.getLogger(__name__)
# Monkeypatch tqdm to fool whisper's `transcribe` function: progress-bar
# updates are routed to a callback instead of (only) the console.
class _TQDM(tqdm.tqdm):
    """tqdm subclass that forwards progress updates to a callback.

    While a progress function is registered via `set_progress_function`,
    every `update` call invokes it with ``(unit, total, accumulated)`` and
    the real progress bar is not advanced. With no function registered the
    instance behaves like a regular tqdm bar.
    """

    # The unpatched tqdm class, captured before `tqdm.tqdm` is replaced below.
    _tqdm = tqdm.tqdm

    # Class-level (process-wide) callback; each instance snapshots it when
    # it is constructed.
    progress_function = None

    def __init__(self, *argv, total=0, unit="", **kwargs):
        """Create the bar, remembering `total` and `unit` for the callback.

        Args:
            *argv: Positional arguments passed through to tqdm.
            total: Expected number of units; reported to the callback.
            unit: Unit label; reported to the callback.
            **kwargs: Keyword arguments passed through to tqdm.
        """
        logger.debug(f"Creating TQDM with total={total}, unit={unit}")
        self._total = total
        self._unit = unit
        self._progress = 0
        self.progress_function = _TQDM.progress_function or None
        # Forward explicit values so the fallback console bar still knows its
        # total and unit; when they are left at their (falsy) defaults tqdm's
        # own defaults apply, exactly as before.
        if total:
            kwargs["total"] = total
        if unit:
            kwargs["unit"] = unit
        super().__init__(*argv, **kwargs)

    @staticmethod
    def set_progress_function(progress_function: Union[Callable[[str, int, int], None], None]):
        """Register (or clear, with ``None``) the process-wide callback.

        Only bars created after this call pick up the new function. Declared
        as a static method so it behaves the same whether it is called on
        the class or on an instance.
        """
        logger.debug(f"Setting progress function to {progress_function}")
        _TQDM.progress_function = progress_function

    def update(self, progress=1):
        """Advance the bar by `progress` units.

        The default of 1 mirrors `tqdm.update(n=1)` so that callers using a
        bare ``bar.update()`` keep working after the monkeypatch.

        Returns:
            ``None`` when the callback handled the update, otherwise
            whatever the real `tqdm.update` returns.
        """
        logger.debug(f"Updating TQDM with progress={progress}")
        self._progress += progress
        if self.progress_function is not None:
            self.progress_function(self._unit, self._total, self._progress)
            return None
        return _TQDM._tqdm.update(self, progress)


# Install the patch. Only code that looks up `tqdm.tqdm` after this point
# sees the replacement class.
tqdm.tqdm = _TQDM
# Which ASR backend to use; chosen through the ASR_ENGINE environment
# variable and defaulting to faster_whisper. Any other value selects the
# openai_whisper backend. Both backends are imported under the same names.
ASR_ENGINE = os.getenv("ASR_ENGINE", "faster_whisper")

if ASR_ENGINE != "faster_whisper":
    from .openai_whisper.core import load_model, transcribe as whisper_transcribe
else:
    from .faster_whisper.core import load_model, transcribe as whisper_transcribe
# Sorted keys of whisper's tokenizer.LANGUAGES mapping.
LANGUAGE_CODES = sorted(tokenizer.LANGUAGES)

# Model used when a request does not name one; overridable via ASR_MODEL.
DEFAULT_MODEL_NAME = os.getenv("ASR_MODEL", "small")

# Custom task state names reported through `update_state`, keyed by phase.
STATES = dict(
    loading_model='LOADING_MODEL',
    encoding='ENCODING',
    transcribing='TRANSCRIBING',
)
# Celery application for this worker, named after this module.
celery = Celery(__name__)
# Opt in to retrying the broker connection during worker startup.
celery.conf.broker_connection_retry_on_startup = True
# Broker and result backend both come from the environment and fall back to
# a Redis instance on localhost.
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379")
celery.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379")
# NOTE(review): presumably keeps Celery from replacing the root logger setup
# done by logging.basicConfig near the top of this file -- confirm against
# the Celery configuration reference.
celery.conf.worker_hijack_root_logger = False
# NOTE(review): presumably the log level used for output that Celery
# redirects from stdout/stderr -- confirm against the Celery docs.
celery.conf.worker_redirect_stdouts_level = "DEBUG"
@celery.task(name="transcribe", bind=True)
def transcribe(
    self,
    audio_file_path: str,
    original_filename: str,
    asr_options: dict,
):
    """Celery task: transcribe an audio file and write the result to disk.

    The task reports its phase through ``self.update_state`` using the
    names in ``STATES`` (model loading, encoding, transcribing). While it
    runs, tqdm progress updates are routed to ``update_progress(self)``.

    Args:
        audio_file_path: Path of the audio file to read. The file is deleted
            after a successful transcription.
        original_filename: Name used as the stem of the output file. It is
            reduced to Latin-1 characters and stripped of any directory
            components before use.
        asr_options: Options for the run. ``"output"`` (required) is the
            output format and file extension; ``"model_name"`` (optional)
            overrides ``DEFAULT_MODEL_NAME``; ``"encode"`` (optional,
            default ``False``) is passed to ``load_audio``. The whole dict
            is also forwarded to the ASR backend.

    Returns:
        A dict with ``"output_filename"``, ``"output_path"`` and
        ``"url_path"`` describing where the transcript was written.

    Raises:
        KeyError: If ``asr_options`` has no ``"output"`` entry.
    """
    logger.info(f"Transcribing {audio_file_path} with {asr_options}")

    output_format = asr_options["output"]
    job_id = self.request.id

    with open(audio_file_path, "rb") as audio_file:
        # Route tqdm updates to this task's state for the duration of the
        # run; always cleared again in the `finally` below.
        _TQDM.set_progress_function(update_progress(self))
        try:
            model_name = asr_options.get("model_name") or DEFAULT_MODEL_NAME
            logger.info(f"Loading model {model_name}")
            self.update_state(
                state=STATES["loading_model"],
                meta={"progress": {"units": "models", "total": 1, "current": 0}},
            )
            # Return value is not used here; presumably the backend keeps
            # the loaded model itself -- verify in the backend's core module.
            load_model(model_name)

            logger.info(f"Loading audio from {audio_file_path}")
            self.update_state(
                state=STATES["encoding"],
                meta={"progress": {"units": "files", "total": 1, "current": 0}},
            )
            audio_data = load_audio(audio_file, asr_options.get("encode", False))

            logger.info("Transcribing audio")
            self.update_state(
                state=STATES["transcribing"],
                meta={"progress": {"units": "files", "total": 1, "current": 0}},
            )
            result = whisper_transcribe(audio_data, asr_options, output_format)
        finally:
            _TQDM.set_progress_function(None)

    logger.info("Transcription complete")
    # NOTE(review): the input file is only removed on success; after a
    # failure it stays on disk. Confirm whether that is intended.
    os.remove(audio_file_path)

    # Drop characters outside Latin-1, then decode with the SAME codec:
    # decoding Latin-1 bytes as UTF-8 raises UnicodeDecodeError for names
    # such as "café". basename() keeps a caller-supplied name from pointing
    # outside the job's output directory.
    latin1_name = original_filename.encode("latin-1", "ignore").decode("latin-1")
    filename = f"{os.path.basename(latin1_name)}.{output_format}"
    output_directory = get_output_path(job_id)
    output_path = f"{output_directory}/{filename}"

    logger.info(f"Writing result to {output_path}")
    os.makedirs(output_directory, exist_ok=True)

    # Transcripts can contain any Unicode text, so do not depend on the
    # platform's default encoding.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(result.read())

    url_path = f"{get_output_url_path(job_id)}/{filename}"

    return {
        "output_filename": filename,
        "output_path": output_path,
        "url_path": url_path,
    }
def get_output_path(job_id: str):
    """Return the directory path where output files for `job_id` are written.

    The base directory is taken from the OUTPUT_DIRECTORY environment
    variable, falling back to ``<current working directory>/app/output``.
    """
    fallback_root = os.getcwd() + "/app/output"
    root = os.environ.get("OUTPUT_DIRECTORY", fallback_root)
    return "/".join((root, job_id))
def get_output_url_path(job_id: str):
    """Return the URL path under which output files for `job_id` are exposed.

    The prefix is taken from the OUTPUT_URL_PREFIX environment variable,
    falling back to ``/output``.
    """
    prefix = os.environ.get("OUTPUT_URL_PREFIX", "/output")
    return "/".join((prefix, job_id))
def update_progress(context):
    """Build a progress callback bound to `context`.

    The returned function takes ``(units, total, current)``, logs them, and
    calls ``context.update_state`` with the "transcribing" state and a
    ``progress`` entry holding those three values.
    """

    def do_update(units, total, current):
        logger.info(f"Updating progress with units={units}, total={total}, current={current}")
        progress = {"units": units, "total": total, "current": current}
        context.update_state(state=STATES["transcribing"], meta={"progress": progress})

    return do_update