import os
import tempfile

import streamlit as st
import torch
from pydub import AudioSegment
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

# Run on CPU only; use float32 since float16 is unreliable on CPU.
device = "cpu"
torch_dtype = torch.float32

# Initialize session state so results survive Streamlit reruns.
if "transcription_text" not in st.session_state:
    st.session_state.transcription_text = None
if "srt_content" not in st.session_state:
    st.session_state.srt_content = None


@st.cache_resource
def load_model():
    """Load the Whisper model once and cache it across reruns."""
    model_id = "openai/whisper-large-v3-turbo"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    ).to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    return pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )


def format_srt_time(seconds):
    """Convert a float number of seconds to SRT's HH:MM:SS,mmm format."""
    hours, remainder = divmod(seconds, 3600)
    minutes, seconds = divmod(remainder, 60)
    milliseconds = int((seconds % 1) * 1000)
    return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02},{milliseconds:03}"


st.title("Audio/Video Transcription App")

# Load (or fetch from cache) the transcription pipeline.
pipe = load_model()

# File upload
uploaded_file = st.file_uploader(
    "Upload an audio or video file", type=["mp3", "wav", "mp4", "m4a"]
)

if uploaded_file is not None:
    with st.spinner("Processing audio..."):
        # Write the upload to disk with its original extension so ffmpeg
        # (via pydub) can detect the container format, then convert to WAV.
        # pydub extracts the audio track from video files the same way, so
        # audio and video inputs need no separate handling.
        suffix = os.path.splitext(uploaded_file.name)[1] or ".tmp"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_input:
            temp_input.write(uploaded_file.getbuffer())
            input_path = temp_input.name
        wav_path = input_path + ".wav"
        AudioSegment.from_file(input_path).export(wav_path, format="wav")

        # Run the transcription. return_timestamps="word" yields one chunk
        # (and therefore one SRT cue) per word.
        transcription_result = pipe(wav_path, return_timestamps="word")

        # Clean up the temporary files once transcription is done.
        os.unlink(input_path)
        os.unlink(wav_path)

        # Extract the full text and the per-word chunks.
        st.session_state.transcription_text = transcription_result["text"]
        transcription_chunks = transcription_result["chunks"]

        # Generate SRT content: a sequence number, a time range, and the text.
        srt_content = ""
        for i, chunk in enumerate(transcription_chunks, start=1):
            start_time = chunk["timestamp"][0]
            # The pipeline can report None for the final chunk's end time;
            # fall back to the start time so formatting does not crash.
            end_time = chunk["timestamp"][1] or start_time
            text = chunk["text"]
            srt_content += f"{i}\n"
            srt_content += f"{format_srt_time(start_time)} --> {format_srt_time(end_time)}\n"
            srt_content += f"{text.strip()}\n\n"
        st.session_state.srt_content = srt_content

# Display transcription
if st.session_state.transcription_text:
    st.subheader("Transcription")
    st.write(st.session_state.transcription_text)

# Provide download for SRT file
if st.session_state.srt_content:
    st.subheader("Download SRT File")
    st.download_button(
        label="Download SRT",
        data=st.session_state.srt_content,
        file_name="transcription.srt",
        mime="text/plain",
    )
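
# How to run (a sketch; "app.py" is an assumed filename, not fixed by the app):
#
#   pip install streamlit transformers torch pydub accelerate
#   streamlit run app.py
#
# accelerate is needed because the model is loaded with low_cpu_mem_usage=True.
# pydub shells out to ffmpeg to decode mp3/mp4/m4a, so ffmpeg must be installed
# and on the PATH; plain WAV input works without it.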