vericudebuget committed
Commit ca365ff · verified · 1 Parent(s): b4bfd19

Update app.py

Files changed (1)
  1. app.py +89 -67
app.py CHANGED
@@ -1,27 +1,50 @@
+# requirements.txt
+
+
+# app.py
 import streamlit as st
+import torch
 from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
-from pydub import AudioSegment
 import tempfile
-import torch
 import os
+from moviepy.editor import VideoFileClip
+import datetime
 
-# Set the device to CPU only
-device = "cpu"
-torch_dtype = torch.float32
+def create_srt(chunks):
+    srt_content = ""
+    for i, chunk in enumerate(chunks, start=1):
+        start_time = str(datetime.timedelta(seconds=chunk['timestamp'][0]))
+        end_time = str(datetime.timedelta(seconds=chunk['timestamp'][1]))
+        # Ensure proper SRT timestamp format (HH:MM:SS,mmm)
+        start_time = start_time.rstrip('0').rstrip('.') + ',000' if '.' in start_time else start_time + ',000'
+        end_time = end_time.rstrip('0').rstrip('.') + ',000' if '.' in end_time else end_time + ',000'
+
+        srt_content += f"{i}\n{start_time} --> {end_time}\n{chunk['text']}\n\n"
+    return srt_content
 
-# Initialize session state
-if 'transcription_text' not in st.session_state:
-    st.session_state.transcription_text = None
-if 'srt_content' not in st.session_state:
-    st.session_state.srt_content = None
+def extract_audio(video_path):
+    with VideoFileClip(video_path) as video:
+        audio = video.audio
+        _, temp_audio_path = tempfile.mkstemp(suffix='.mp3')
+        audio.write_audiofile(temp_audio_path)
+    return temp_audio_path
+
+def setup_model():
+    device = "cpu"
+    torch_dtype = torch.float32
 
-@st.cache_resource
-def load_model():
     model_id = "openai/whisper-tiny"
+
     model = AutoModelForSpeechSeq2Seq.from_pretrained(
-        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=False, use_safetensors=True
-    ).to(device)
+        model_id,
+        torch_dtype=torch_dtype,
+        low_cpu_mem_usage=True,
+        use_safetensors=True
+    )
+    model.to(device)
+
     processor = AutoProcessor.from_pretrained(model_id)
+
     pipe = pipeline(
         "automatic-speech-recognition",
         model=model,
@@ -30,65 +53,64 @@ def load_model():
         torch_dtype=torch_dtype,
         device=device,
     )
+
     return pipe
 
-def format_srt_time(seconds):
-    hours, remainder = divmod(seconds, 3600)
-    minutes, seconds = divmod(remainder, 60)
-    milliseconds = int((seconds % 1) * 1000)
-    seconds = int(seconds)
-    return f"{int(hours):02}:{int(minutes):02}:{seconds:02},{milliseconds:03}"
-
-st.title("Audio/Video Transcription App")
-
-# Load model
-pipe = load_model()
-
-# File upload
-uploaded_file = st.file_uploader("Upload an audio or video file", type=["mp3", "wav", "mp4", "m4a"])
+def main():
+    st.title("Audio/Video Transcription App")
+
+    # Initialize session state for model
+    if 'pipe' not in st.session_state:
+        with st.spinner("Loading model... This might take a few minutes."):
+            st.session_state.pipe = setup_model()
 
-if uploaded_file is not None:
-    with st.spinner("Processing audio..."):
-        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio:
-            # If it's a video, extract audio
-            if uploaded_file.name.endswith(("mp4", "m4a")):
-                audio = AudioSegment.from_file(uploaded_file)
-                audio.export(temp_audio.name, format="wav")
+    uploaded_file = st.file_uploader("Upload an audio or video file", type=['mp3', 'wav', 'mp4', 'avi', 'mov'])
+
+    if uploaded_file is not None:
+        with st.spinner("Processing file..."):
+            # Save uploaded file temporarily
+            temp_dir = tempfile.mkdtemp()
+            temp_path = os.path.join(temp_dir, uploaded_file.name)
+
+            with open(temp_path, 'wb') as f:
+                f.write(uploaded_file.getvalue())
+
+            # Extract audio if it's a video file
+            if uploaded_file.type.startswith('video'):
+                audio_path = extract_audio(temp_path)
             else:
-                audio = AudioSegment.from_file(uploaded_file)
-                audio.export(temp_audio.name, format="wav")
+                audio_path = temp_path
 
-            # Run the transcription
-            transcription_result = pipe(temp_audio.name, return_timestamps="word")
+            # Transcribe
+            generate_kwargs = {
+                "return_timestamps": True
+            }
 
-            # Extract text and timestamps
-            st.session_state.transcription_text = transcription_result['text']
-            transcription_chunks = transcription_result['chunks']
+            result = st.session_state.pipe(
+                audio_path,
+                generate_kwargs=generate_kwargs,
+                chunk_length_s=30,
+                batch_size=8
+            )
 
-            # Generate SRT content
-            srt_content = ""
-            for i, chunk in enumerate(transcription_chunks, start=1):
-                start_time = chunk["timestamp"][0]
-                end_time = chunk["timestamp"][1]
-                text = chunk["text"]
-
-                srt_content += f"{i}\n"
-                srt_content += f"{format_srt_time(start_time)} --> {format_srt_time(end_time)}\n"
-                srt_content += f"{text}\n\n"
+            # Display results
+            st.subheader("Transcription:")
+            st.write(result["text"])
 
-            st.session_state.srt_content = srt_content
-
-# Display transcription
-if st.session_state.transcription_text:
-    st.subheader("Transcription")
-    st.write(st.session_state.transcription_text)
+            # Create and offer SRT download
+            srt_content = create_srt(result["chunks"])
+            st.download_button(
+                label="Download SRT file",
+                data=srt_content,
+                file_name="transcription.srt",
+                mime="text/plain"
+            )
+
+            # Cleanup
+            os.remove(temp_path)
+            if uploaded_file.type.startswith('video'):
+                os.remove(audio_path)
+            os.rmdir(temp_dir)
 
-# Provide download for SRT file
-if st.session_state.srt_content:
-    st.subheader("Download SRT File")
-    st.download_button(
-        label="Download SRT",
-        data=st.session_state.srt_content,
-        file_name="transcription.srt",
-        mime="text/plain"
-    )
+if __name__ == "__main__":
+    main()
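
A note on the new create_srt: str(datetime.timedelta(seconds=...)) renders values such as 0:00:05.500000, so the rstrip('0')/rstrip('.') padding produces timestamps like 0:00:05.5,000 rather than the 00:00:05,500 form SRT players expect (hours stay unpadded and the real milliseconds are lost). The removed format_srt_time handled this correctly with divmod. A minimal corrected sketch that keeps the committed function's interface; the helper name srt_timestamp is ours, not from the commit:

    def srt_timestamp(seconds):
        # Convert float seconds into the SRT form HH:MM:SS,mmm.
        millis = int(round(seconds * 1000))
        hours, millis = divmod(millis, 3_600_000)
        minutes, millis = divmod(millis, 60_000)
        secs, millis = divmod(millis, 1_000)
        return f"{hours:02}:{minutes:02}:{secs:02},{millis:03}"

    def create_srt(chunks):
        srt_content = ""
        for i, chunk in enumerate(chunks, start=1):
            start, end = chunk["timestamp"]
            if end is None:  # the pipeline can leave the last chunk open-ended
                end = start
            srt_content += f"{i}\n{srt_timestamp(start)} --> {srt_timestamp(end)}\n{chunk['text'].strip()}\n\n"
        return srt_content

Fed a chunk list shaped like the pipeline's return_timestamps=True output (values invented for illustration):

    chunks = [{"timestamp": (0.0, 4.2), "text": " Hello there."}]
    print(create_srt(chunks))
    # 1
    # 00:00:00,000 --> 00:00:04,200
    # Hello there.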
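
The new app.py also opens with an empty "# requirements.txt" placeholder. A plausible dependency list for this Space, inferred from the imports rather than taken from the commit:

    streamlit
    transformers
    torch
    accelerate      # assumed: low_cpu_mem_usage=True relies on Accelerate
    moviepy<2.0     # moviepy 2.x removed the moviepy.editor module imported above

The transformers ASR pipeline also shells out to ffmpeg to decode mp3/mp4 input, so ffmpeg must be available on the host (via packages.txt on Spaces).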