Spaces: Runtime error
# requirements.txt
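The contents of requirements.txt are not shown in the listing. Inferred from the imports in app.py, a sketch of what it needs is below (an assumption, not the file's actual contents): accelerate backs the low_cpu_mem_usage=True option, and the moviepy pin is there because releases from 2.0 onward dropped the moviepy.editor module that app.py imports from.

    streamlit
    torch
    transformers
    accelerate
    moviepy<2.0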
# app.py
import streamlit as st
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import tempfile
import os
from moviepy.editor import VideoFileClip
import datetime

def format_timestamp(seconds):
    # SRT timestamps use the form HH:MM:SS,mmm; the end timestamp of the
    # last chunk can come back as None, so fall back to 0.0 in that case
    if seconds is None:
        seconds = 0.0
    delta = datetime.timedelta(seconds=seconds)
    hours, remainder = divmod(int(delta.total_seconds()), 3600)
    minutes, secs = divmod(remainder, 60)
    millis = delta.microseconds // 1000
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

def create_srt(chunks):
    srt_content = ""
    for i, chunk in enumerate(chunks, start=1):
        start_time = format_timestamp(chunk['timestamp'][0])
        end_time = format_timestamp(chunk['timestamp'][1])
        srt_content += f"{i}\n{start_time} --> {end_time}\n{chunk['text'].strip()}\n\n"
    return srt_content

def extract_audio(video_path):
    # mkstemp returns an open OS-level file descriptor; close it so moviepy
    # can write to the path without leaking the handle
    fd, temp_audio_path = tempfile.mkstemp(suffix='.mp3')
    os.close(fd)
    with VideoFileClip(video_path) as video:
        video.audio.write_audiofile(temp_audio_path)
    return temp_audio_path

def setup_model():
    # CPU-only settings; whisper-tiny is the smallest Whisper checkpoint
    device = "cpu"
    torch_dtype = torch.float32
    model_id = "openai/whisper-tiny"
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)
    processor = AutoProcessor.from_pretrained(model_id)
    pipe = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device,
    )
    return pipe

def main():
    st.title("Audio/Video Transcription App")

    # Initialize session state for model
    if 'pipe' not in st.session_state:
        with st.spinner("Loading model... This might take a few minutes."):
            st.session_state.pipe = setup_model()

    uploaded_file = st.file_uploader("Upload an audio or video file", type=['mp3', 'wav', 'mp4', 'avi', 'mov'])

    if uploaded_file is not None:
        with st.spinner("Processing file..."):
            # Save uploaded file temporarily
            temp_dir = tempfile.mkdtemp()
            temp_path = os.path.join(temp_dir, uploaded_file.name)
            with open(temp_path, 'wb') as f:
                f.write(uploaded_file.getvalue())

            # Extract audio if it's a video file
            if uploaded_file.type.startswith('video'):
                audio_path = extract_audio(temp_path)
            else:
                audio_path = temp_path
            # Transcribe; return_timestamps=True makes the pipeline return a
            # "chunks" list (with start/end times) alongside the full text
            result = st.session_state.pipe(
                audio_path,
                return_timestamps=True,
                chunk_length_s=30,
                batch_size=8
            )
            # Display results
            st.subheader("Transcription:")
            st.write(result["text"])

            # Create and offer SRT download
            srt_content = create_srt(result["chunks"])
            st.download_button(
                label="Download SRT file",
                data=srt_content,
                file_name="transcription.srt",
                mime="text/plain"
            )

            # Cleanup
            os.remove(temp_path)
            if uploaded_file.type.startswith('video'):
                os.remove(audio_path)
            os.rmdir(temp_dir)

if __name__ == "__main__":
    main()
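Outside of Spaces, the app should run locally with streamlit run app.py once the requirements are installed and an ffmpeg binary is available (moviepy and the speech pipeline both rely on it; on Spaces that normally means a packages.txt file containing ffmpeg, assuming the base image does not already ship it). The model and SRT code can also be exercised without the Streamlit layer. The sketch below is an assumption, not part of the Space: it expects app.py to be importable and a short sample.wav to sit next to it, and it mirrors the pipeline call made in main().

    # smoke_test.py - minimal sketch; "sample.wav" is a placeholder file name
    from app import setup_model, create_srt

    pipe = setup_model()  # downloads openai/whisper-tiny on first use
    result = pipe("sample.wav", return_timestamps=True, chunk_length_s=30, batch_size=8)

    print(result["text"])                 # full transcript
    print(create_srt(result["chunks"]))   # SRT cues built from the timestamped chunks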