import gradio as gr
import torch
import whisper
from accelerate import dispatch_model, infer_auto_device_map

# Pick the Whisper checkpoint to serve.
model_name = "tiny"  # Change to "base", "small", etc., as needed
print(f"Loading the Whisper model: {model_name} ...")

# whisper.load_model() downloads the checkpoint and loads the weights (here on
# CPU), so there is no separate checkpoint file to hand to Accelerate's
# load_checkpoint_and_dispatch(); instead, the already-loaded model is
# dispatched across the available devices.
whisper_model = whisper.load_model(model_name, device="cpu")

if torch.cuda.is_available():
    whisper_model = whisper_model.half()  # Mixed precision for efficiency on GPU

# Automatically balance the model's layers across the available devices
# (GPUs if present, otherwise CPU).
device_map = infer_auto_device_map(whisper_model)
print(f"Using device map: {device_map}")
whisper_model = dispatch_model(whisper_model, device_map=device_map)
print("Model successfully loaded and dispatched with Accelerate!")
# Define the transcription function
def transcribe(audio):
    # Run Whisper on the recorded clip; only request fp16 when a GPU is available
    result = whisper_model.transcribe(audio, fp16=torch.cuda.is_available())
    return result["text"]
# Create the Gradio interface
demo = gr.Interface(
    fn=transcribe,  # The function to be called for transcription
    inputs=gr.Audio(sources=["microphone"], type="filepath", label="Speak into the microphone"),  # Input audio ("sources" is the Gradio 4+ name; older releases used source="microphone")
    outputs=gr.Textbox(label="Transcription"),  # Output transcription
    title="Whisper Speech-to-Text with Accelerate",  # Title of the interface
    description="Record audio using your microphone and get a transcription from the Whisper model dispatched across devices with Accelerate.",
)
# Launch the Gradio interface
demo.launch()
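
# Note (assumption): on a Hugging Face Space this file is typically paired with
# a requirements.txt listing the packages imported above, e.g.:
#   gradio
#   openai-whisper
#   accelerate
#   torch
# Running `python app.py` locally starts the Gradio server (by default on
# http://127.0.0.1:7860); transcribe() can also be called directly with a path
# to an audio file for a quick check without the UI.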