# speechtotextt / app.py
import streamlit as st
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import wave
import numpy as np
import tempfile
import os
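
# Assumed (unpinned) dependencies, sketched as a requirements.txt for this app:
#
#     streamlit
#     torch
#     transformers
#     accelerate   # required because from_pretrained is called with low_cpu_mem_usage=True
#     numpy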

# Page configuration
st.set_page_config(
    page_title="Speech to Text Converter",
    page_icon="🎙️",
    layout="wide"
)


@st.cache_resource
def load_pipeline():
    """Load the model and processor, and create the ASR pipeline."""
    device = "cpu"
    torch_dtype = torch.float32
    model_id = "distil-whisper/distil-large-v3"

    # Load model
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)

    # Load processor
    processor = AutoProcessor.from_pretrained(model_id)

    # Create pipeline
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        max_new_tokens=128,
        chunk_length_s=30,
        batch_size=8,
        torch_dtype=torch_dtype,
        device=device,
    )
    return pipe
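
# Minimal usage sketch for the cached pipeline (illustrative only; the app's
# real call with raw audio happens in main() below). The HF ASR pipeline also
# accepts a file path, assuming ffmpeg is available for decoding;
# "sample.wav" here is a hypothetical file:
#
#     pipe = load_pipeline()
#     result = pipe("sample.wav")
#     print(result["text"])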


def read_wav_file(wav_file):
    """Read a WAV file with the standard-library wave module and return
    mono float32 audio resampled to 16 kHz."""
    with wave.open(wav_file, 'rb') as wav:
        # Get WAV file parameters
        channels = wav.getnchannels()
        sample_width = wav.getsampwidth()
        sample_rate = wav.getframerate()
        n_frames = wav.getnframes()
        # Read raw audio data
        raw_data = wav.readframes(n_frames)

    # Convert bytes to a numpy array, then to float32 in [-1.0, 1.0]
    if sample_width == 1:
        # 8-bit WAV is unsigned, centered at 128
        samples = np.frombuffer(raw_data, dtype=np.uint8)
        audio_data = (samples.astype(np.float32) - 128.0) / 128.0
    elif sample_width == 2:
        # 16-bit WAV is signed
        samples = np.frombuffer(raw_data, dtype=np.int16)
        audio_data = samples.astype(np.float32) / np.iinfo(np.int16).max
    else:
        raise ValueError(f"Unsupported sample width: {sample_width} bytes")

    # If multi-channel, convert to mono by averaging the channels
    if channels > 1:
        audio_data = audio_data.reshape(-1, channels).mean(axis=1)

    # Resample to 16 kHz if necessary (simple linear interpolation;
    # see the note after this function)
    if sample_rate != 16000:
        original_length = len(audio_data)
        desired_length = int(original_length * 16000 / sample_rate)
        indices = np.linspace(0, original_length - 1, desired_length)
        audio_data = np.interp(indices, np.arange(original_length), audio_data)

    return audio_data
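
# Note: np.interp above is a crude linear-interpolation resampler and can
# alias high-frequency content. A higher-quality sketch, assuming scipy were
# added as a dependency (it is not one here):
#
#     from scipy.signal import resample_poly
#     audio_data = resample_poly(audio_data, 16000, sample_rate)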


def main():
    st.title("🎙️ Speech to Text Converter")
    st.markdown("### Upload a WAV file and convert speech to text")

    # Load pipeline (cached by st.cache_resource, so this is fast after the first run)
    with st.spinner("Loading model... This might take a few minutes the first time."):
        try:
            pipe = load_pipeline()
            st.success("Model loaded successfully! Ready to transcribe.")
        except Exception as e:
            st.error(f"Error loading model: {e}")
            return

    # File upload
    audio_file = st.file_uploader(
        "Upload your audio file",
        type=['wav'],
        help="Only WAV files are supported. For better performance, keep files under 5 minutes."
    )

    if audio_file is not None:
        # Write the upload to a temporary file so wave.open() can read it by path
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
            tmp_file.write(audio_file.getvalue())
            temp_path = tmp_file.name

        try:
            # Display audio player
            st.audio(audio_file)

            # Transcribe button
            if st.button("🎯 Transcribe Audio", type="primary"):
                progress_bar = st.progress(0)
                status_text = st.empty()
                try:
                    # Read audio file
                    status_text.text("Reading audio file...")
                    progress_bar.progress(25)
                    audio_data = read_wav_file(temp_path)

                    # Transcribe with the pipeline
                    status_text.text("Transcribing... This might take a while.")
                    progress_bar.progress(50)
                    result = pipe(
                        {"raw": audio_data, "sampling_rate": 16000},
                        return_timestamps=True
                    )

                    # Update progress
                    progress_bar.progress(100)
                    status_text.text("Transcription completed!")

                    # Display results
                    st.markdown("### Transcription Result:")
                    st.write(result["text"])

                    # Display timestamps if available
                    if "chunks" in result:
                        st.markdown("### Timestamps:")
                        for chunk in result["chunks"]:
                            st.write(f"{chunk['timestamp']}: {chunk['text']}")

                    # Download button
                    st.download_button(
                        label="📥 Download Transcription",
                        data=result["text"],
                        file_name="transcription.txt",
                        mime="text/plain"
                    )
                except Exception as e:
                    st.error(f"An error occurred: {e}")
        finally:
            # Clean up the temporary file
            if os.path.exists(temp_path):
                os.remove(temp_path)

    # Usage instructions
    with st.expander("ℹ️ Usage Instructions"):
        st.markdown("""
        ### Instructions:
        1. Upload a WAV file (16-bit PCM format recommended)
        2. Click 'Transcribe Audio'
        3. Wait for processing to complete
        4. View or download the transcription

        ### Notes:
        - Only WAV files are supported
        - Keep files under 5 minutes for best results
        - Audio should be clear with minimal background noise
        - The transcription includes timestamps for better reference
        """)

    # Footer
    st.markdown("---")
    st.markdown("Made with ❤️ using the Distil-Whisper model")


if __name__ == "__main__":
    main()