File size: 3,688 Bytes
ca365ff
 
 
 
dc800de
ca365ff
dc800de
 
 
ca365ff
 
dc800de
ca365ff
 
 
 
 
 
 
 
 
 
 
dc800de
ca365ff
 
 
 
 
 
 
 
 
 
dc800de
b4bfd19
ca365ff
b45ed63
ca365ff
 
 
 
 
 
 
b45ed63
ca365ff
b45ed63
 
 
 
 
 
 
 
ca365ff
b45ed63
 
ca365ff
 
 
 
 
 
 
dc800de
ca365ff
 
 
 
 
 
 
 
 
 
 
 
 
 
b45ed63
ca365ff
b45ed63
ca365ff
 
 
 
b45ed63
ca365ff
 
 
 
 
 
b45ed63
ca365ff
 
 
b45ed63
ca365ff
 
 
 
 
 
 
 
 
 
 
 
 
 
dc800de
ca365ff
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# requirements.txt


# app.py
import streamlit as st
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import tempfile
import os
from moviepy.editor import VideoFileClip
import datetime

def _format_srt_timestamp(seconds):
    """Render a number of seconds as an SRT timestamp (``HH:MM:SS,mmm``)."""
    # Work in integer milliseconds so float noise cannot leak into the output.
    total_ms = max(0, int(round(float(seconds) * 1000)))
    total_seconds, millis = divmod(total_ms, 1000)
    total_minutes, secs = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def create_srt(chunks):
    """Build the text of an SRT subtitle file from transcription chunks.

    Args:
        chunks: Iterable of dicts, each with a ``'timestamp'`` entry holding a
            ``(start, end)`` pair in seconds and a ``'text'`` entry holding the
            subtitle text. Either bound may be ``None``.

    Returns:
        The SRT document as a single string; empty when ``chunks`` is empty.
        Cues are numbered from 1 and each is followed by a blank line.
    """
    entries = []
    previous_end = 0.0
    for index, chunk in enumerate(chunks, start=1):
        start, end = chunk['timestamp']
        # A missing boundary would crash the formatter; fall back to the
        # nearest known time so the cue is still emitted.
        if start is None:
            start = previous_end
        if end is None:
            end = start
        previous_end = end

        entries.append(
            f"{index}\n"
            f"{_format_srt_timestamp(start)} --> {_format_srt_timestamp(end)}\n"
            f"{chunk['text']}\n\n"
        )
    return "".join(entries)

def extract_audio(video_path):
    """Write the audio track of a video to a temporary MP3 file.

    Args:
        video_path: Path of the video file to read.

    Returns:
        Path of the newly created ``.mp3`` file. The caller is responsible
        for deleting it.

    Raises:
        ValueError: If the video has no audio track.
    """
    fd, temp_audio_path = tempfile.mkstemp(suffix='.mp3')
    # mkstemp hands back an open descriptor; close it right away so it is not
    # leaked and so the writer below can open the path itself.
    os.close(fd)
    try:
        with VideoFileClip(video_path) as video:
            audio = video.audio
            if audio is None:
                raise ValueError(f"No audio track found in {video_path!r}")
            audio.write_audiofile(temp_audio_path)
    except Exception:
        # Do not leave an empty or partial file behind on failure.
        if os.path.exists(temp_audio_path):
            os.remove(temp_audio_path)
        raise
    return temp_audio_path

def setup_model():
    """Create the speech-recognition pipeline used for transcription.

    Loads the ``openai/whisper-tiny`` checkpoint in float32, moves the model
    to the CPU, and wraps it together with its tokenizer and feature
    extractor in an ``automatic-speech-recognition`` pipeline.

    Returns:
        The configured pipeline object.
    """
    checkpoint = "openai/whisper-tiny"
    target_device = "cpu"
    dtype = torch.float32

    whisper = AutoModelForSpeechSeq2Seq.from_pretrained(
        checkpoint,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    whisper.to(target_device)

    # The processor bundles both the tokenizer and the feature extractor.
    preprocessing = AutoProcessor.from_pretrained(checkpoint)

    return pipeline(
        "automatic-speech-recognition",
        model=whisper,
        tokenizer=preprocessing.tokenizer,
        feature_extractor=preprocessing.feature_extractor,
        torch_dtype=dtype,
        device=target_device,
    )

def main():
    """Run the Streamlit page: upload a file, transcribe it, offer an SRT.

    The pipeline is built once per session and kept in ``st.session_state``.
    An uploaded file is written to a temporary directory, its audio is
    extracted first when the upload's MIME type starts with ``video``, and
    the transcription text is shown together with a download button for the
    generated SRT. All temporary files are removed even if processing fails.
    """
    st.title("Audio/Video Transcription App")

    # Initialize session state for model
    if 'pipe' not in st.session_state:
        with st.spinner("Loading model... This might take a few minutes."):
            st.session_state.pipe = setup_model()

    uploaded_file = st.file_uploader("Upload an audio or video file", type=['mp3', 'wav', 'mp4', 'avi', 'mov'])

    if uploaded_file is None:
        return

    is_video = uploaded_file.type.startswith('video')

    with st.spinner("Processing file..."):
        # The directory and the uploaded copy inside it are deleted when the
        # context exits, including when an exception propagates.
        with tempfile.TemporaryDirectory() as temp_dir:
            # Keep only the final path component of the client-supplied name
            # so the file cannot be written outside temp_dir.
            safe_name = os.path.basename(uploaded_file.name) or "upload"
            temp_path = os.path.join(temp_dir, safe_name)

            with open(temp_path, 'wb') as f:
                f.write(uploaded_file.getvalue())

            extracted_audio_path = None
            try:
                # Extract audio if it's a video file
                if is_video:
                    extracted_audio_path = extract_audio(temp_path)
                audio_path = extracted_audio_path or temp_path

                # Transcribe
                generate_kwargs = {
                    "return_timestamps": True
                }

                result = st.session_state.pipe(
                    audio_path,
                    generate_kwargs=generate_kwargs,
                    chunk_length_s=30,
                    batch_size=8
                )
            finally:
                # The extracted audio lives outside temp_dir, so remove it
                # explicitly.
                if extracted_audio_path is not None and os.path.exists(extracted_audio_path):
                    os.remove(extracted_audio_path)

        # Display results
        st.subheader("Transcription:")
        st.write(result["text"])

        # Create and offer SRT download
        srt_content = create_srt(result["chunks"])
        st.download_button(
            label="Download SRT file",
            data=srt_content,
            file_name="transcription.srt",
            mime="text/plain"
        )

# Run the app only when this module is executed as the main script,
# not when it is imported.
if __name__ == "__main__":
    main()