import streamlit as st
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from pydub import AudioSegment
import tempfile
import torch
import os

# Set the device to CPU only
device = "cpu"
torch_dtype = torch.float32

# Initialize session state so results persist across Streamlit reruns
if 'transcription_text' not in st.session_state:
    st.session_state.transcription_text = None
if 'srt_content' not in st.session_state:
    st.session_state.srt_content = None

# Cache the model so it is loaded once per session instead of on every rerun
@st.cache_resource
def load_model():
    model_id = "openai/whisper-large-v3-turbo"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    ).to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
    return pipe

def format_srt_time(seconds):
    hours, remainder = divmod(seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    milliseconds = int((seconds % 1) * 1000)
    seconds = int(seconds)
    return f"{int(hours):02}:{int(minutes):02}:{seconds:02},{milliseconds:03}"
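
# Illustrative example of the SRT timestamp format produced above
# (hypothetical input value, not taken from the original app):
#   format_srt_time(3725.5)  ->  "01:02:05,500"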

st.title("Audio/Video Transcription App")

# Load model
pipe = load_model()

# File upload
uploaded_file = st.file_uploader("Upload an audio or video file", type=["mp3", "wav", "mp4", "m4a"])

if uploaded_file is not None:
    with st.spinner("Processing audio..."):
        # Convert the upload to WAV; pydub/ffmpeg reads audio and video
        # containers (mp3, wav, mp4, m4a) alike, so one code path covers both
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
            audio = AudioSegment.from_file(uploaded_file)
            audio.export(temp_audio.name, format="wav")

        # Run the transcription with word-level timestamps
        transcription_result = pipe(temp_audio.name, return_timestamps="word")
        os.remove(temp_audio.name)  # clean up the temporary WAV file
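
        # Expected result layout (a sketch based on how the fields are used
        # below; the words and times here are made up):
        # {"text": " Hello world",
        #  "chunks": [{"text": " Hello", "timestamp": (0.0, 0.42)},
        #             {"text": " world", "timestamp": (0.42, 0.9)}]}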

        # Extract text and timestamps
        st.session_state.transcription_text = transcription_result['text']
        transcription_chunks = transcription_result['chunks']

        # Generate SRT content
        srt_content = ""
        for i, chunk in enumerate(transcription_chunks, start=1):
            start_time = chunk["timestamp"][0]
            end_time = chunk["timestamp"][1]
            # Whisper can return None for the final end timestamp; fall back to the start
            if end_time is None:
                end_time = start_time
            text = chunk["text"]
            srt_content += f"{i}\n"
            srt_content += f"{format_srt_time(start_time)} --> {format_srt_time(end_time)}\n"
            srt_content += f"{text}\n\n"

        st.session_state.srt_content = srt_content
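
        # Each SRT entry written by the loop above looks like this
        # (illustrative values only):
        # 1
        # 00:00:00,000 --> 00:00:00,420
        #  Hello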

# Display transcription
if st.session_state.transcription_text:
    st.subheader("Transcription")
    st.write(st.session_state.transcription_text)

# Provide download for SRT file
if st.session_state.srt_content:
    st.subheader("Download SRT File")
    st.download_button(
        label="Download SRT",
        data=st.session_state.srt_content,
        file_name="transcription.srt",
        mime="text/plain"
    )
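
# Usage note (an assumption, not part of the original Space): if this script is
# saved as app.py, it can be run locally with `streamlit run app.py`, provided
# ffmpeg is installed so pydub can decode mp3/mp4/m4a inputs.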