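"""Gradio demo for speech-to-text transcription with Whisper models.

Pick a fine-tuned model from the mozilla-ai Common Voice collection, paste any
Hugging Face model id, or (when not running on a Space) point to a local model
directory, then record or upload audio and transcribe it.
"""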
import os
import gradio as gr
import spaces
from huggingface_hub import get_collection, HfApi
from transformers import pipeline, Pipeline

# Set in the Hugging Face Space environment; used below to hide the
# local-model option when the demo runs on a Space.
is_hf_space = os.getenv("IS_HF_SPACE")


def get_dropdown_model_ids():
    mozilla_ai_model_ids = []
    # Fetch model ids from the collection and append each model's language
    # in parentheses, read from the model card metadata.
    for model_i in get_collection(
        "mozilla-ai/common-voice-whisper-67b847a74ad7561781aa10fd"
    ).items:
        model_metadata = HfApi().model_info(model_i.item_id)
        # The card's model_name is expected to end with "on <Language>"
        language = model_metadata.card_data.model_name.split("on ")[1]
        mozilla_ai_model_ids.append(model_i.item_id + f" ({language})")

    return (
        [""]
        + mozilla_ai_model_ids
        + [
            "openai/whisper-tiny (Multilingual)",
            "openai/whisper-small (Multilingual)",
            "openai/whisper-medium (Multilingual)",
            "openai/whisper-large-v3 (Multilingual)",
            "openai/whisper-large-v3-turbo (Multilingual)",
        ]
    )
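
# Illustrative shape of the returned choices (the mozilla-ai repo id below is
# hypothetical; the openai ids are the real ones listed above):
#   ["", "mozilla-ai/whisper-small-gl (Galician)", ..., "openai/whisper-tiny (Multilingual)"]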


def _load_local_model(model_dir: str) -> Pipeline | str:
    from transformers import WhisperProcessor, WhisperForConditionalGeneration

    try:
        processor = WhisperProcessor.from_pretrained(model_dir)
        model = WhisperForConditionalGeneration.from_pretrained(model_dir)
        return pipeline(
            task="automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            chunk_length_s=30,  # max input duration for whisper
        )
    except Exception as e:
        # Return the error message so the UI can display it
        return str(e)


def _load_hf_model(model_repo_id: str) -> Pipeline | str:
    try:
        return pipeline(
            "automatic-speech-recognition",
            model=model_repo_id,
            chunk_length_s=30,  # max input duration for whisper
        )
    except Exception as e:
        return str(e)


# Copied from https://github.com/openai/whisper/blob/517a43ecd132a2089d85f4ebc044728a71d49f6e/whisper/utils.py#L50
def format_timestamp(
    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    )
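
# Worked example: format_timestamp(3661.5) == "01:01:01.500"
# (3,661,500 ms -> 1 h, 1 min, 1 s, 500 ms; hours are shown since hours > 0).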


@spaces.GPU(duration=30)
def transcribe(
    dropdown_model_id: str,
    hf_model_id: str,
    local_model_id: str,
    audio: str,  # filepath from the gr.Audio component (type="filepath")
    show_timestamps: bool,
) -> str:
    if audio is None:
        return "⚠️ Error: Please record or upload an audio file first"
    if dropdown_model_id and not hf_model_id and not local_model_id:
        # Strip the " (Language)" suffix shown in the dropdown
        dropdown_model_id = dropdown_model_id.split(" (")[0]
        pipe = _load_hf_model(dropdown_model_id)
    elif hf_model_id and not local_model_id and not dropdown_model_id:
        pipe = _load_hf_model(hf_model_id)
    elif local_model_id and not hf_model_id and not dropdown_model_id:
        pipe = _load_local_model(local_model_id)
    else:
        return "⚠️ Error: Please select or fill exactly one of the model options above"
    if isinstance(pipe, str):
        # Exception raised when loading
        return f"⚠️ Error: {pipe}"

    output = pipe(
        audio,
        generate_kwargs={"task": "transcribe"},
        batch_size=16,
        return_timestamps=show_timestamps,
    )
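    # With return_timestamps=True, the pipeline output also carries a "chunks"
    # list of {"timestamp": (start_s, end_s), "text": ...} entries.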
    text = output["text"]
    if show_timestamps:
        text = "\n".join(
            f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
            for chunk in output["chunks"]
        )
    return text


def setup_gradio_demo():
    with gr.Blocks() as demo:
        gr.Markdown(
            """ # 🗣️ Speech-to-Text Transcription
            ### 1. Select which model to use from one of the options below.
            ### 2. Record a message or upload an audio file.
            ### 3. Click Transcribe to see the transcription generated by the model.
            """
        )
        ### Model selection ###
        model_ids = get_dropdown_model_ids()
        with gr.Row():
            with gr.Column():
                dropdown_model = gr.Dropdown(
                    choices=model_ids, label="Option 1: Select a model"
                )
            with gr.Column():
                user_model = gr.Textbox(
                    label="Option 2: Paste HF model id",
                    placeholder="my-username/my-whisper-tiny",
                )
            with gr.Column(visible=not is_hf_space):
                local_model = gr.Textbox(
                    label="Option 3: Paste local path to model directory",
                    placeholder="artifacts/my-whisper-tiny",
                )

        ### Transcription ###
        with gr.Group():
            audio_input = gr.Audio(
                sources=["microphone", "upload"],
                type="filepath",
                label="Record a message / Upload audio file",
                show_download_button=True,
            )
            timestamps_check = gr.Checkbox(label="Show timestamps")

        transcribe_button = gr.Button("Transcribe")
        transcribe_output = gr.Text(label="Output")

        transcribe_button.click(
            fn=transcribe,
            inputs=[
                dropdown_model,
                user_model,
                local_model,
                audio_input,
                timestamps_check,
            ],
            outputs=transcribe_output,
        )

    demo.launch()


if __name__ == "__main__":
    setup_gradio_demo()