File size: 4,706 Bytes
402daee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import logging
import os

from celery import Celery
from typing import Union, Callable
from whisper import tokenizer
import tqdm

from .util.audio import load_audio

# Configure the root logger once at import time. `force=True` makes
# basicConfig remove any handlers already attached to the root logger before
# installing this format, so the configuration below always wins.
logging.basicConfig(format='[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s', level=logging.INFO, force=True)
# Module-level logger used by everything in this file.
logger = logging.getLogger(__name__)

# monkeypatch tqdm to fool whisper's `transcribe` function
class _TQDM(tqdm.tqdm):
    """``tqdm.tqdm`` subclass that can forward progress to a callback.

    When a progress function has been registered through
    :meth:`set_progress_function`, every :meth:`update` call reports
    ``(unit, total, accumulated_progress)`` to it instead of advancing the
    regular progress bar. Without a registered function, updates fall through
    to the original tqdm implementation.
    """

    # The original tqdm class, captured when this class body is executed.
    _tqdm = tqdm.tqdm
    # Class-wide callback; copied onto each instance when it is constructed.
    progress_function = None

    def __init__(self, *argv, total=0, unit="", **kwargs):
        """Record ``total`` and ``unit`` for reporting, then initialise tqdm.

        ``total`` and ``unit`` are intercepted here and are not forwarded to
        the parent constructor; all other arguments are passed through.
        """
        logger.debug(f"Creating TQDM with total={total}, unit={unit}")
        self._total = total
        self._unit = unit
        self._progress = 0
        # Snapshot the callback that is registered at construction time.
        self.progress_function = _TQDM.progress_function or None
        super().__init__(*argv, **kwargs)

    @staticmethod
    def set_progress_function(progress_function: Union[Callable[[str, int, int], None], None]):
        """Register (or clear, with ``None``) the class-wide progress callback.

        Declared as a static method so it behaves the same whether it is
        called on the class or on an instance. Only instances created after
        this call pick up the new callback.
        """
        logger.debug(f"Setting progress function to {progress_function}")
        _TQDM.progress_function = progress_function

    def update(self, progress=1):
        """Advance the progress counter by ``progress`` and report it.

        ``progress`` defaults to 1 so that a bare ``update()`` call keeps
        working the way it does on a plain tqdm bar.

        Returns whatever the original ``tqdm.update`` returns when no callback
        is registered, otherwise ``None``.
        """
        logger.debug(f"Updating TQDM with progress={progress}")
        self._progress += progress
        if self.progress_function is not None:
            self.progress_function(self._unit, self._total, self._progress)
            return None
        return _TQDM._tqdm.update(self, progress)

# Replace the `tqdm` attribute on the tqdm module so that code which looks up
# `tqdm.tqdm` after this point gets the callback-aware subclass.
tqdm.tqdm = _TQDM

# Select the ASR backend from the environment; anything other than
# "faster_whisper" falls back to the openai_whisper implementation.
ASR_ENGINE = os.getenv("ASR_ENGINE", "faster_whisper")
if ASR_ENGINE == "faster_whisper":
    from .faster_whisper.core import load_model, transcribe as whisper_transcribe
else:
    from .openai_whisper.core import load_model, transcribe as whisper_transcribe

# Sorted list of the language codes known to whisper's tokenizer.
LANGUAGE_CODES = sorted(list(tokenizer.LANGUAGES.keys()))

# Model used when a job does not name one explicitly.
DEFAULT_MODEL_NAME = os.getenv("ASR_MODEL", "small")

# Custom Celery task states reported while a transcription job is running.
STATES = {
    'loading_model': 'LOADING_MODEL',
    'encoding': 'ENCODING',
    'transcribing': 'TRANSCRIBING',
}
# Celery application; broker and result backend both default to a local Redis.
celery = Celery(__name__)
celery.conf.broker_connection_retry_on_startup = True
celery.conf.broker_url = os.environ.get("CELERY_BROKER_URL", "redis://localhost:6379")
celery.conf.result_backend = os.environ.get("CELERY_RESULT_BACKEND", "redis://localhost:6379")
# Keep the logging configuration set up in this module instead of letting the
# Celery worker reconfigure the root logger.
celery.conf.worker_hijack_root_logger = False
# Log level used for output the worker captures from stdout/stderr.
celery.conf.worker_redirect_stdouts_level = "DEBUG"

@celery.task(name="transcribe", bind=True)
def transcribe(
    self,
    audio_file_path: str,
    original_filename: str,
    asr_options: dict,
):
    """Celery task: transcribe an audio file and store the result on disk.

    The task loads the requested model, loads the audio, runs the selected
    whisper backend, deletes the input file, and writes the transcription to
    the job's output directory. Task state is updated at each stage, and
    per-chunk progress is forwarded through ``_TQDM`` while the work runs.

    Args:
        audio_file_path: Path of the audio file to transcribe. The file is
            removed after a successful transcription.
        original_filename: Client-side name of the file; used to build the
            output filename.
        asr_options: Options for the run. ``"output"`` (required) is the
            output format and file extension; ``"model_name"`` and
            ``"encode"`` are optional. The whole dict is also passed to the
            backend.

    Returns:
        A dict with ``output_filename``, ``output_path`` and ``url_path``.

    Raises:
        KeyError: If ``asr_options`` has no ``"output"`` entry.
    """
    logger.info(f"Transcribing {audio_file_path} with {asr_options}")
    output_format = asr_options["output"]

    with open(audio_file_path, "rb") as audio_file:
        # Route tqdm progress from the backend into Celery task state.
        _TQDM.set_progress_function(update_progress(self))

        try:
            model_name = asr_options.get("model_name") or DEFAULT_MODEL_NAME
            logger.info(f"Loading model {model_name}")
            self.update_state(state=STATES["loading_model"], meta={"progress": {"units": "models", "total": 1, "current": 0}})
            load_model(model_name)

            logger.info(f"Loading audio from {audio_file_path}")
            self.update_state(state=STATES["encoding"], meta={"progress": {"units": "files", "total": 1, "current": 0}})
            audio_data = load_audio(audio_file, asr_options.get("encode", False))

            logger.info("Transcribing audio")
            self.update_state(state=STATES["transcribing"], meta={"progress": {"units": "files", "total": 1, "current": 0}})
            result = whisper_transcribe(audio_data, asr_options, output_format)
        finally:
            # Always unregister the callback, even when transcription fails.
            _TQDM.set_progress_function(None)

    logger.info("Transcription complete")

    os.remove(audio_file_path)

    # Drop characters that are not representable in Latin-1. The bytes must be
    # decoded with the same codec they were encoded with: decoding Latin-1
    # bytes as UTF-8 raises UnicodeDecodeError for names such as "café".
    latin1_name = original_filename.encode("latin-1", "ignore").decode("latin-1")
    # The name comes from the client; keep only its final path component so it
    # cannot point outside the job's output directory.
    safe_name = os.path.basename(latin1_name)
    filename = f"{safe_name}.{output_format}"
    output_directory = get_output_path(self.request.id)
    output_path = f"{output_directory}/{filename}"

    logger.info(f"Writing result to {output_path}")

    # exist_ok avoids the check-then-create race between concurrent workers.
    os.makedirs(output_directory, exist_ok=True)

    # Transcripts routinely contain non-ASCII text; write UTF-8 explicitly
    # instead of depending on the worker's locale.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(result.read())

    # Use the bound task instance, consistent with the output directory above.
    url_path = f"{get_output_url_path(self.request.id)}/{filename}"

    return {
        "output_filename": filename,
        "output_path": output_path,
        "url_path": url_path,
    }

def get_output_path(job_id: str):
    """Return the directory that holds the output files of job ``job_id``.

    The base directory comes from the ``OUTPUT_DIRECTORY`` environment
    variable and falls back to ``<current working directory>/app/output``.
    """
    fallback_root = os.getcwd() + "/app/output"
    root = os.environ.get("OUTPUT_DIRECTORY", fallback_root)
    return "/".join((root, job_id))

def get_output_url_path(job_id: str):
    """Return the URL path under which the output of job ``job_id`` is served.

    The prefix comes from the ``OUTPUT_URL_PREFIX`` environment variable and
    defaults to ``/output``.
    """
    prefix = os.environ.get("OUTPUT_URL_PREFIX", "/output")
    return "/".join((prefix, job_id))

def update_progress(context):
    """Build a progress callback bound to ``context``.

    The returned function takes ``(units, total, current)``, logs them, and
    stores them as the ``progress`` metadata of ``context`` via
    ``context.update_state`` using the "transcribing" task state.
    """

    def do_update(units, total, current):
        logger.info(f"Updating progress with units={units}, total={total}, current={current}")
        progress_meta = {"units": units, "total": total, "current": current}
        task_state = STATES["transcribing"]
        context.update_state(state=task_state, meta={"progress": progress_meta})

    return do_update