|
|
|
|
|
import gradio as gr |
|
import whisper |
|
from accelerate import init_empty_weights, load_checkpoint_and_dispatch |
|
import torch |
|
|
|
|
|
# --- Model loading -----------------------------------------------------------
# NOTE(review): despite the "ZeRO" wording in the log messages, accelerate's
# device_map dispatch is not DeepSpeed ZeRO. The messages are kept verbatim
# for output compatibility.
device_map = "auto"

print(f"Using ZeRO-powered device map: {device_map}")

model_name = "tiny"  # Whisper checkpoint size: tiny/base/small/medium/large

print(f"Loading the Whisper model: {model_name} with ZeRO optimization...")

# BUG FIX: the original created the model under init_empty_weights() (leaving
# every parameter as a data-less "meta" tensor) and then called
# load_checkpoint_and_dispatch() WITHOUT its required positional `checkpoint`
# argument, which raises TypeError before the app ever starts. Since
# whisper.load_model() already downloads and fully materializes the weights,
# load it directly on the best available device instead.
_device = "cuda" if torch.cuda.is_available() else "cpu"
whisper_model = whisper.load_model(model_name, device=_device)

# Half precision only on CUDA: fp16 inference on CPU is not supported by
# Whisper and would fail or silently degrade.
if _device == "cuda":
    whisper_model = whisper_model.half()

print("Model successfully loaded with ZeRO optimization!")
|
|
|
|
|
def transcribe(audio):
    """Transcribe a recorded audio clip with the module-level Whisper model.

    Parameters
    ----------
    audio : str | None
        Filesystem path to the recording (Gradio ``type="filepath"``), or
        ``None`` when the user submits without recording anything.

    Returns
    -------
    str
        The transcribed text, or ``""`` when no audio was provided.
    """
    # ROBUSTNESS FIX: Gradio delivers None when the submit button is pressed
    # with no recording; passing None to Whisper crashes with an opaque error.
    if audio is None:
        return ""
    result = whisper_model.transcribe(audio)
    return result["text"]
|
|
|
|
|
# Wire the transcription function into a simple record-and-transcribe UI.
# NOTE(review): `source=` is the Gradio 3.x parameter name; it was renamed to
# `sources=[...]` in Gradio 4.x — confirm the pinned gradio version.
microphone_input = gr.Audio(source="microphone", type="filepath", label="Speak into the microphone")
transcript_output = gr.Textbox(label="Transcription")

demo = gr.Interface(
    fn=transcribe,
    inputs=microphone_input,
    outputs=transcript_output,
    title="Whisper Speech-to-Text with ZeRO",
    description="Record audio using your microphone and get a transcription using the Whisper model optimized with ZeRO."
)

demo.launch()
|
|