abdullah-alnahas committed
Commit 356f877 · 1 Parent(s): 91a85ec

feat(app.py): add download video and audio options

Files changed (1)
  1. app.py +261 -104
app.py CHANGED
@@ -9,122 +9,112 @@ from silero_vad import load_silero_vad, get_speech_timestamps
9
  import numpy as np
10
  import pydub
11
 
12
- VAD_SENSITIVITY = 0.1
13
-
14
  # --- Model Loading and Caching ---
15
  @st.cache_resource
16
  def load_transcriber(_device):
 
17
  transcriber = pipeline(model="openai/whisper-large-v3-turbo", device=_device)
18
  return transcriber
19
 
20
  @st.cache_resource
21
  def load_vad_model():
 
22
  return load_silero_vad()
23
 
24
  # --- Audio Processing Functions ---
25
  @st.cache_resource
26
- def download_and_convert_audio(video_url):
27
  status_message = st.empty()
28
  status_message.text("Downloading audio...")
29
  try:
30
  ydl_opts = {
31
- 'format': 'bestaudio/best',
32
  'postprocessors': [{
33
  'key': 'FFmpegExtractAudio',
34
- 'preferredcodec': 'wav',
35
- 'preferredquality': '192',
36
  }],
37
  'outtmpl': '%(id)s.%(ext)s',
 
 
38
  }
39
  with yt_dlp.YoutubeDL(ydl_opts) as ydl:
40
  info = ydl.extract_info(video_url, download=False)
 
 
41
  video_id = info['id']
42
- filename = f"{video_id}.wav"
43
  ydl.download([video_url])
44
- status_message.text("Audio downloaded and converted.")
45
-
46
- # Read the file and return its contents
47
  with open(filename, 'rb') as audio_file:
48
  audio_bytes = audio_file.read()
49
-
50
- # Clean up the temporary file
51
  os.remove(filename)
52
-
53
- return audio_bytes, 'wav'
54
  except Exception as e:
55
  st.error(f"Error during download or conversion: {e}")
56
- return None, None
57
 
58
- def aggregate_speech_segments(speech_timestamps, max_duration=30):
59
- """Aggregates speech segments into chunks with a maximum duration,
60
- merging the last segment if it's contained within the second-to-last.
61
 
62
  Args:
63
- speech_timestamps: A list of dictionaries, where each dictionary represents
64
- a speech segment with 'start' and 'end' timestamps
65
- (in seconds).
66
- max_duration: The maximum desired duration of each aggregated segment
67
- (in seconds). Defaults to 30.
 
68
 
69
  Returns:
70
- A list of dictionaries, where each dictionary represents an aggregated
71
- speech segment with 'start' and 'end' timestamps.
72
  """
73
-
74
- if not speech_timestamps:
75
- return []
76
-
77
- aggregated_segments = []
78
- current_segment_start = speech_timestamps[0]['start']
79
- current_segment_end = speech_timestamps[0]['end']
80
-
81
- for segment in speech_timestamps[1:]:
82
- if segment['start'] - current_segment_start >= max_duration:
83
- # Start a new segment if the current duration exceeds max_duration
84
- aggregated_segments.append({'start': current_segment_start, 'end': current_segment_end})
85
- current_segment_start = segment['start']
86
- current_segment_end = segment['end']
87
- else:
88
- # Extend the current segment
89
- current_segment_end = segment['end']
90
-
91
- # Add the last segment, checking for redundancy
92
- last_segment = {'start': current_segment_start, 'end': current_segment_end}
93
- if aggregated_segments:
94
- second_last_segment = aggregated_segments[-1]
95
- if last_segment['start'] >= second_last_segment['start'] and last_segment['end'] <= second_last_segment['end']:
96
- # Last segment is fully contained in the second-to-last, so don't add it
97
- pass
98
- else:
99
- aggregated_segments.append(last_segment)
100
- else:
101
- # If aggregated_segments is empty, add the last segment
102
- aggregated_segments.append(last_segment)
103
-
104
- return aggregated_segments
105
-
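
The docstring above describes how speech segments are merged into chunks of at most max_duration seconds. A minimal, self-contained sketch of that merging rule (made-up timestamps, and omitting the containment check on the final segment for brevity) behaves like this:

def aggregate(speech_timestamps, max_duration=30):
    # Merge consecutive segments until the span from the current chunk's
    # start exceeds max_duration, then start a new chunk.
    if not speech_timestamps:
        return []
    merged = []
    start, end = speech_timestamps[0]['start'], speech_timestamps[0]['end']
    for seg in speech_timestamps[1:]:
        if seg['start'] - start >= max_duration:
            merged.append({'start': start, 'end': end})
            start, end = seg['start'], seg['end']
        else:
            end = seg['end']
    merged.append({'start': start, 'end': end})
    return merged

print(aggregate([
    {'start': 0.0, 'end': 12.5},
    {'start': 13.0, 'end': 28.0},
    {'start': 31.0, 'end': 45.0},
    {'start': 46.0, 'end': 55.0},
]))
# [{'start': 0.0, 'end': 28.0}, {'start': 31.0, 'end': 55.0}]
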
106
- @st.cache_data
107
- def split_audio_by_vad(audio_data: bytes, ext: str, _vad_model, sensitivity: float, return_seconds: bool = True):
108
  if not audio_data:
109
  st.error("No audio data received.")
110
  return []
111
-
112
  try:
113
  audio = pydub.AudioSegment.from_file(io.BytesIO(audio_data), format=ext)
114
-
115
- # VAD parameters
116
  rate = audio.frame_rate
117
  window_size_samples = int(512 + (1536 - 512) * (1 - sensitivity))
118
  speech_threshold = 0.5 + (0.95 - 0.5) * sensitivity
119
 
120
- # Convert audio to numpy array for VAD
121
  samples = np.array(audio.get_array_of_samples())
122
 
123
- # Get speech timestamps
124
  speech_timestamps = get_speech_timestamps(
125
- samples,
126
  _vad_model,
127
- sampling_rate=rate,
128
  return_seconds=return_seconds,
129
  window_size_samples=window_size_samples,
130
  threshold=speech_threshold,
@@ -134,43 +124,45 @@ def split_audio_by_vad(audio_data: bytes, ext: str, _vad_model, sensitivity: flo
134
  st.warning("No speech segments detected.")
135
  return []
136
 
137
- # rectify timestamps
138
  speech_timestamps[0]["start"] = 0.
139
  speech_timestamps[-1]['end'] = audio.duration_seconds
140
  for i, chunk in enumerate(speech_timestamps[1:], start=1):
141
- chunk["start"] = speech_timestamps[i-1]['end']
142
-
143
- # Aggregate segments into ~30 second chunks
144
- aggregated_segments = aggregate_speech_segments(speech_timestamps, max_duration=30)
145
 
146
  if not aggregated_segments:
147
  return []
148
 
149
- # Create audio chunks based on timestamps
150
  chunks = []
151
  for segment in aggregated_segments:
152
  start_ms = int(segment['start'] * 1000)
153
  end_ms = int(segment['end'] * 1000)
154
  chunk = audio[start_ms:end_ms]
155
-
156
- # Export chunk to bytes
157
  chunk_io = io.BytesIO()
158
  chunk.export(chunk_io, format=ext)
159
- chunk_data = chunk_io.getvalue() # Get bytes directly
160
-
161
  chunks.append({
162
- 'data': chunk_data,
163
  'start': segment['start'],
164
  'end': segment['end']
165
  })
166
- chunk_io.close() # Close the BytesIO object after getting the value
167
-
168
  return chunks
169
  except Exception as e:
170
  st.error(f"Error processing audio in split_audio_by_vad: {str(e)}")
171
  return []
172
  finally:
173
- # Explicitly release pydub resources to prevent memory issues
174
  if 'audio' in locals():
175
  del audio
176
  if 'samples' in locals():
@@ -178,18 +170,28 @@ def split_audio_by_vad(audio_data: bytes, ext: str, _vad_model, sensitivity: flo
178
 
179
  @st.cache_data
180
  def transcribe_batch(batch, _transcriber, language=None):
181
  transcriptions = []
182
  for i, chunk_data in enumerate(batch):
183
  try:
184
  generate_kwargs = {
185
  "task": "transcribe",
186
- "return_timestamps": True
 
187
  }
188
- if language:
189
- generate_kwargs["language"] = language
190
-
191
  transcription = _transcriber(
192
- chunk_data['data'],
193
  generate_kwargs=generate_kwargs
194
  )
195
  transcriptions.append({
@@ -204,47 +206,93 @@ def transcribe_batch(batch, _transcriber, language=None):
204
 
205
  # --- Streamlit App ---
206
  def setup_ui():
 
207
  st.title("YouTube Video Transcriber")
208
- video_url = st.text_input("YouTube Video Link:")
209
- language = st.text_input("Language (two-letter code, e.g., 'en', 'es', leave empty for auto-detection):", max_chars=2)
210
- batch_size = st.number_input("Batch Size", min_value=1, max_value=10, value=2) # Batch size selection
211
- transcribe_button = st.button("Transcribe")
212
- return video_url, language,batch_size, transcribe_button
213
 
214
  @st.cache_resource
215
  def initialize_models():
 
216
  device = "cuda" if torch.cuda.is_available() else "cpu"
217
  transcriber = load_transcriber(device)
218
  vad_model = load_vad_model()
219
  return transcriber, vad_model
220
 
221
- def process_transcription(video_url, vad_sensitivity, batch_size, transcriber, vad_model, language=None):
222
- transcription_output = st.empty()
223
- audio_data, ext = download_and_convert_audio(video_url)
224
  if not audio_data:
225
- return
226
-
227
- chunks = split_audio_by_vad(audio_data, ext, vad_model, vad_sensitivity)
228
  if not chunks:
229
- return
230
 
231
  total_chunks = len(chunks)
232
  transcriptions = []
 
233
  for i in range(0, total_chunks, batch_size):
234
  batch = chunks[i:i + batch_size]
235
  batch_transcriptions = transcribe_batch(batch, transcriber, language)
236
  transcriptions.extend(batch_transcriptions)
237
- display_transcription(transcriptions, transcription_output)
238
 
 
239
  st.success("Transcription complete!")
240
 
241
- def display_transcription(transcriptions, output_area):
242
  full_transcription = ""
243
  for chunk in transcriptions:
244
  start_time = format_seconds(chunk['start'])
245
  end_time = format_seconds(chunk['end'])
246
  full_transcription += f"[{start_time} - {end_time}]: {chunk['text'].strip()}\n\n"
247
- output_area.text_area("Transcription:", value=full_transcription, height=300, key=random.random())
 
248
 
249
  def format_seconds(seconds):
250
  """Formats seconds into HH:MM:SS string."""
@@ -252,14 +300,123 @@ def format_seconds(seconds):
252
  hours, minutes = divmod(minutes, 60)
253
  return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}"
254
 
255
  def main():
256
  transcriber, vad_model = initialize_models()
257
- video_url, language, batch_size, transcribe_button = setup_ui()
258
- if transcribe_button:
259
  if not video_url:
260
  st.error("Please enter a YouTube video link.")
261
  return
262
- process_transcription(video_url, VAD_SENSITIVITY, batch_size, transcriber, vad_model, language)
263
 
264
  if __name__ == "__main__":
265
  main()
 
9
  import numpy as np
10
  import pydub
11
 
 
 
12
  # --- Model Loading and Caching ---
13
  @st.cache_resource
14
  def load_transcriber(_device):
15
+ """Loads the Whisper transcription model."""
16
  transcriber = pipeline(model="openai/whisper-large-v3-turbo", device=_device)
17
  return transcriber
18
 
19
  @st.cache_resource
20
  def load_vad_model():
21
+ """Loads the Silero VAD model."""
22
  return load_silero_vad()
23
 
24
  # --- Audio Processing Functions ---
25
  @st.cache_resource
26
+ def download_and_convert_audio(video_url, audio_format="wav"):
27
+ """Downloads and converts audio from a YouTube video.
28
+
29
+ Args:
30
+ video_url (str): The URL of the YouTube video.
31
+ audio_format (str): The desired audio format (e.g., "wav", "mp3").
32
+
33
+ Returns:
34
+ tuple: (audio_bytes, audio_format, info_dict) or (None, None, None) on error.
35
+ """
36
  status_message = st.empty()
37
  status_message.text("Downloading audio...")
38
  try:
39
  ydl_opts = {
40
+ 'format': f'bestaudio/best',
41
  'postprocessors': [{
42
  'key': 'FFmpegExtractAudio',
43
+ 'preferredcodec': audio_format,
 
44
  }],
45
  'outtmpl': '%(id)s.%(ext)s',
46
+ 'noplaylist': True,
47
+ 'progress_hooks': [lambda d: update_download_progress(d, status_message)],
48
  }
49
  with yt_dlp.YoutubeDL(ydl_opts) as ydl:
50
  info = ydl.extract_info(video_url, download=False)
51
+ if 'entries' in info:
52
+ info = info['entries'][0]
53
  video_id = info['id']
54
+ filename = f"{video_id}.{audio_format}"
55
+
56
+ audio_formats = [f for f in info.get('formats', []) if f.get('acodec') != 'none' and f.get('vcodec') == 'none']
57
+ if not audio_formats:
58
+ st.warning(f"No audio-only format found. Downloading and converting from best video format to {audio_format}.")
59
+ ydl_opts['format'] = 'best'
60
+
61
  ydl.download([video_url])
62
+ status_message.text(f"Audio downloaded and converted to {audio_format}.")
63
+
 
64
  with open(filename, 'rb') as audio_file:
65
  audio_bytes = audio_file.read()
66
+
 
67
  os.remove(filename)
68
+ return audio_bytes, audio_format, info
 
69
  except Exception as e:
70
  st.error(f"Error during download or conversion: {e}")
71
+ return None, None, None
72
+
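
As the docstring notes, the rewritten function now returns a three-tuple rather than the old (audio_bytes, 'wav') pair. A hedged usage sketch (the URL is a placeholder):

audio_bytes, fmt, info = download_and_convert_audio(
    "https://www.youtube.com/watch?v=VIDEO_ID", audio_format="mp3"
)
if audio_bytes is not None:
    # info is the yt-dlp info dict returned alongside the bytes.
    st.write(f"Downloaded {len(audio_bytes)} bytes of {fmt} audio for '{info.get('title', 'unknown')}'")
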
73
+ def update_download_progress(d, status_message):
74
+ """Updates the download progress in the Streamlit UI."""
75
+ if d['status'] == 'downloading':
76
+ p = round(d['downloaded_bytes'] / d['total_bytes'] * 100)
77
+ status_message.text(f"Downloading: {p}%")
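
The hook above divides by d['total_bytes'], which yt-dlp may not populate for every stream; sometimes only total_bytes_estimate is available. A more defensive variant (a sketch, not part of the commit) could look like:

def safe_progress_hook(d, status_message):
    # Fall back to the estimate, and skip the percentage if neither is known.
    if d.get('status') != 'downloading':
        return
    total = d.get('total_bytes') or d.get('total_bytes_estimate')
    if total:
        pct = round(d.get('downloaded_bytes', 0) / total * 100)
        status_message.text(f"Downloading: {pct}%")
    else:
        status_message.text("Downloading...")
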
78
 
79
+ @st.cache_data
80
+ def split_audio_by_vad(audio_data: bytes, ext: str, _vad_model, sensitivity: float, max_duration: int = 30, return_seconds: bool = True):
81
+ """Splits audio into chunks based on voice activity detection (VAD).
82
 
83
  Args:
84
+ audio_data (bytes): The audio data as bytes.
85
+ ext (str): The audio file extension.
86
+ _vad_model: The VAD model.
87
+ sensitivity (float): The VAD sensitivity (0.0 to 1.0).
88
+ max_duration (int): The maximum duration of each chunk in seconds.
89
+ return_seconds (bool): Whether to return timestamps in seconds.
90
 
91
  Returns:
92
+ list: A list of dictionaries, where each dictionary represents an audio chunk.
93
+ Returns an empty list if no speech segments are detected or an error occurs.
94
  """
95
+
96
  if not audio_data:
97
  st.error("No audio data received.")
98
  return []
99
+
100
  try:
101
  audio = pydub.AudioSegment.from_file(io.BytesIO(audio_data), format=ext)
 
 
102
  rate = audio.frame_rate
103
+
104
+ # Convert to mono if stereo for compatibility with VAD
105
+ if audio.channels > 1:
106
+ audio = audio.set_channels(1)
107
+
108
+ # Calculate dynamic VAD parameters based on sensitivity
109
  window_size_samples = int(512 + (1536 - 512) * (1 - sensitivity))
110
  speech_threshold = 0.5 + (0.95 - 0.5) * sensitivity
111
 
 
112
  samples = np.array(audio.get_array_of_samples())
113
 
 
114
  speech_timestamps = get_speech_timestamps(
115
+ samples,
116
  _vad_model,
117
+ sampling_rate=rate,
118
  return_seconds=return_seconds,
119
  window_size_samples=window_size_samples,
120
  threshold=speech_threshold,
 
124
  st.warning("No speech segments detected.")
125
  return []
126
 
 
127
  speech_timestamps[0]["start"] = 0.
128
  speech_timestamps[-1]['end'] = audio.duration_seconds
129
  for i, chunk in enumerate(speech_timestamps[1:], start=1):
130
+ chunk["start"] = speech_timestamps[i - 1]['end']
131
 
132
+ aggregated_segments = []
133
+ if speech_timestamps:
134
+ current_segment_start = speech_timestamps[0]['start']
135
+ current_segment_end = speech_timestamps[0]['end']
136
+ for segment in speech_timestamps[1:]:
137
+ if segment['start'] - current_segment_start >= max_duration:
138
+ aggregated_segments.append({'start': current_segment_start, 'end': current_segment_end})
139
+ current_segment_start = segment['start']
140
+ current_segment_end = segment['end']
141
+ else:
142
+ current_segment_end = segment['end']
143
+ aggregated_segments.append({'start': current_segment_start, 'end': current_segment_end})
144
+
145
  if not aggregated_segments:
146
  return []
147
 
 
148
  chunks = []
149
  for segment in aggregated_segments:
150
  start_ms = int(segment['start'] * 1000)
151
  end_ms = int(segment['end'] * 1000)
152
  chunk = audio[start_ms:end_ms]
 
 
153
  chunk_io = io.BytesIO()
154
  chunk.export(chunk_io, format=ext)
 
 
155
  chunks.append({
156
+ 'data': chunk_io.getvalue(),
157
  'start': segment['start'],
158
  'end': segment['end']
159
  })
160
+ chunk_io.close()
 
161
  return chunks
162
  except Exception as e:
163
  st.error(f"Error processing audio in split_audio_by_vad: {str(e)}")
164
  return []
165
  finally:
 
166
  if 'audio' in locals():
167
  del audio
168
  if 'samples' in locals():
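
For reference, the sensitivity-to-parameter mapping used in split_audio_by_vad above works out as follows at the default slider value of 0.1 (plain arithmetic, not app code):

sensitivity = 0.1
window_size_samples = int(512 + (1536 - 512) * (1 - sensitivity))   # 1433
speech_threshold = 0.5 + (0.95 - 0.5) * sensitivity                 # 0.545

Higher sensitivity therefore shrinks the analysis window toward 512 samples and raises the speech threshold toward 0.95, making detection stricter and finer-grained.
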
 
170
 
171
  @st.cache_data
172
  def transcribe_batch(batch, _transcriber, language=None):
173
+ """Transcribes a batch of audio chunks.
174
+
175
+ Args:
176
+ batch (list): A list of audio chunk dictionaries.
177
+ _transcriber: The transcription model.
178
+ language (str, optional): The language of the audio (e.g., "en", "es"). Defaults to None (auto-detection).
179
+
180
+ Returns:
181
+ list: A list of dictionaries, each containing the transcription, start, and end time of a chunk.
182
+ Returns an empty list if an error occurs.
183
+ """
184
  transcriptions = []
185
  for i, chunk_data in enumerate(batch):
186
  try:
187
  generate_kwargs = {
188
  "task": "transcribe",
189
+ "return_timestamps": True,
190
+ "language": language
191
  }
192
+
 
 
193
  transcription = _transcriber(
194
+ chunk_data['data'],
195
  generate_kwargs=generate_kwargs
196
  )
197
  transcriptions.append({
 
206
 
207
  # --- Streamlit App ---
208
  def setup_ui():
209
+ """Sets up the Streamlit user interface."""
210
  st.title("YouTube Video Transcriber")
211
+
212
+ col1, col2, col3, col4 = st.columns(4)
213
+ with col1:
214
+ transcribe_option = st.checkbox("Transcribe", value=True)
215
+ with col2:
216
+ download_audio_option = st.checkbox("Download Audio", value=False)
217
+ with col3:
218
+ download_video_option = st.checkbox("Download Video", value=False)
219
+ with col4:
220
+ pass
221
+
222
+ video_url = st.text_input("YouTube Video Link:", key="video_url")
223
+ language = st.text_input("Language (two-letter code, e.g., 'en', 'es', leave empty for auto-detection):", max_chars=2, key="language")
224
+ batch_size = st.number_input("Batch Size", min_value=1, value=2, key="batch_size")
225
+ vad_sensitivity = st.slider("VAD Sensitivity", min_value=0.0, max_value=1.0, value=0.1, step=0.05, key="vad_sensitivity")
226
+
227
+ # Use session state to manage audio format selection and reset
228
+ if 'reset_audio_format' not in st.session_state:
229
+ st.session_state.reset_audio_format = False
230
+
231
+ if 'audio_format' not in st.session_state or st.session_state.reset_audio_format:
232
+ st.session_state.audio_format = "wav" # Default value
233
+ st.session_state.reset_audio_format = False
234
+
235
+ audio_format = st.selectbox("Audio Format", ["wav", "mp3", "ogg", "flac"], key="audio_format_widget", index=["wav", "mp3", "ogg", "flac"].index(st.session_state.audio_format))
236
+ st.session_state.audio_format = audio_format
237
+
238
+ if download_video_option:
239
+ video_format = st.selectbox("Video Format", ["mp4", "webm"], index=0, key="video_format")
240
+ else:
241
+ video_format = "mp4"
242
+
243
+ process_button = st.button("Process")
244
+
245
+ return video_url, language, batch_size, transcribe_option, download_audio_option, download_video_option, process_button, vad_sensitivity, audio_format, video_format
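
The audio-format block above keeps the selection in st.session_state so it survives Streamlit reruns and can be reset after processing. The same pattern in isolation (a standalone sketch; 'demo_format' is a made-up key):

import streamlit as st

if "demo_format" not in st.session_state:
    st.session_state.demo_format = "wav"        # default on first run

formats = ["wav", "mp3", "ogg", "flac"]
choice = st.selectbox("Audio Format", formats,
                      index=formats.index(st.session_state.demo_format))
st.session_state.demo_format = choice           # persist the latest choice across reruns
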
246
 
247
  @st.cache_resource
248
  def initialize_models():
249
+ """Initializes the transcription and VAD models."""
250
  device = "cuda" if torch.cuda.is_available() else "cpu"
251
  transcriber = load_transcriber(device)
252
  vad_model = load_vad_model()
253
  return transcriber, vad_model
254
 
255
+ def process_transcription(video_url, vad_sensitivity, batch_size, transcriber, vad_model, audio_format, language=None):
256
+ """Downloads, processes, and transcribes the audio from a YouTube video.
257
+
258
+ Args:
259
+ video_url (str): The URL of the YouTube video.
260
+ vad_sensitivity (float): The VAD sensitivity.
261
+ batch_size (int): The batch size for transcription.
262
+ transcriber: The transcription model.
263
+ vad_model: The VAD model.
264
+ language (str, optional): The language of the audio. Defaults to None.
265
+
266
+ Returns:
267
+ tuple: (full_transcription, audio_data, audio_format, info) or (None, None, None, None) on error.
268
+ """
269
+ audio_data, audio_format, info = download_and_convert_audio(video_url, audio_format)
270
  if not audio_data:
271
+ return None, None, None, None
272
+
273
+ chunks = split_audio_by_vad(audio_data, audio_format, vad_model, vad_sensitivity)
274
  if not chunks:
275
+ return None, None, None, None
276
 
277
  total_chunks = len(chunks)
278
  transcriptions = []
279
+ progress_bar = st.progress(0)
280
  for i in range(0, total_chunks, batch_size):
281
  batch = chunks[i:i + batch_size]
282
  batch_transcriptions = transcribe_batch(batch, transcriber, language)
283
  transcriptions.extend(batch_transcriptions)
284
+ progress_bar.progress((i + len(batch)) / total_chunks)
285
 
286
+ progress_bar.empty()
287
  st.success("Transcription complete!")
288
 
 
289
  full_transcription = ""
290
  for chunk in transcriptions:
291
  start_time = format_seconds(chunk['start'])
292
  end_time = format_seconds(chunk['end'])
293
  full_transcription += f"[{start_time} - {end_time}]: {chunk['text'].strip()}\n\n"
294
+
295
+ return full_transcription, audio_data, audio_format, info
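
The loop above advances the st.progress bar once per batch rather than once per chunk. With made-up numbers (7 chunks, batch size 2) the reported fractions are:

total_chunks, batch_size = 7, 2
chunks = list(range(total_chunks))
for i in range(0, total_chunks, batch_size):
    batch = chunks[i:i + batch_size]
    print(round((i + len(batch)) / total_chunks, 2))
# 0.29, 0.57, 0.86, 1.0  (the bar reaches 1.0 on the final, possibly short, batch)
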
296
 
297
  def format_seconds(seconds):
298
  """Formats seconds into HH:MM:SS string."""
 
300
  hours, minutes = divmod(minutes, 60)
301
  return f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}"
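
A quick worked example of the HH:MM:SS conversion above (its first step, splitting off the seconds, falls just outside the displayed hunk): 3725 seconds becomes "01:02:05".

minutes, seconds = divmod(3725, 60)   # 62, 5
hours, minutes = divmod(minutes, 60)  # 1, 2
print(f"{int(hours):02}:{int(minutes):02}:{int(seconds):02}")   # 01:02:05
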
302
 
303
+ def download_video(video_url, video_format):
304
+ """Downloads video from YouTube using yt-dlp."""
305
+ status_message = st.empty()
306
+ status_message.text("Downloading video...")
307
+ try:
308
+ ydl_opts = {
309
+ 'format': f'bestvideo[ext={video_format}]+bestaudio[ext=m4a]/best[ext={video_format}]/best',
310
+ 'outtmpl': '%(title)s.%(ext)s',
311
+ 'noplaylist': True,
312
+ 'progress_hooks': [lambda d: update_download_progress(d, status_message)],
313
+ }
314
+ with yt_dlp.YoutubeDL(ydl_opts) as ydl:
315
+ info_dict = ydl.extract_info(video_url, download=True)
316
+ video_filename = ydl.prepare_filename(info_dict)
317
+ video_title = info_dict.get("title", "video")
318
+ status_message.text(f"Video downloaded: {video_title}")
319
+
320
+ with open(video_filename, 'rb') as video_file:
321
+ video_bytes = video_file.read()
322
+
323
+ os.remove(video_filename)
324
+
325
+ return video_bytes, video_filename, info_dict
326
+ except Exception as e:
327
+ st.error(f"Error during video download: {e}")
328
+ return None, None, None
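
The format string built in download_video uses yt-dlp's selector syntax: '+' merges a video stream with an audio stream, and '/'-separated alternatives are tried left to right. Expanded for the default video_format of "mp4" (merging separate streams typically requires ffmpeg to be available):

video_format = "mp4"
selector = f'bestvideo[ext={video_format}]+bestaudio[ext=m4a]/best[ext={video_format}]/best'
print(selector)
# bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best
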
329
+
330
+ import random
331
+ import streamlit as st
332
+ import io
333
+ import os
334
+ from transformers import pipeline
335
+ import torch
336
+ import yt_dlp
337
+ from silero_vad import load_silero_vad, get_speech_timestamps
338
+ import numpy as np
339
+ import pydub
340
+
341
+ # ... (rest of your code, including model loading, audio functions, etc.)
342
+
343
  def main():
344
+ """Main function to run the Streamlit application."""
345
+
346
+ # Initialize session state variables
347
+ if 'full_transcription' not in st.session_state:
348
+ st.session_state.full_transcription = None
349
+ if 'audio_data' not in st.session_state:
350
+ st.session_state.audio_data = None
351
+ if 'info' not in st.session_state:
352
+ st.session_state.info = None
353
+ if 'video_data' not in st.session_state:
354
+ st.session_state.video_data = None
355
+ if 'video_filename' not in st.session_state:
356
+ st.session_state.video_filename = None
357
+
358
  transcriber, vad_model = initialize_models()
359
+
360
+ # Call setup_ui() to get UI element values
361
+ video_url, language, batch_size, transcribe_option, download_audio_option, download_video_option, process_button, vad_sensitivity, audio_format, video_format = setup_ui()
362
+
363
+ transcription_output = st.empty()
364
+ if st.session_state.full_transcription:
365
+ transcription_output.text_area("Transcription:", value=st.session_state.full_transcription, height=300, key=random.random())
366
+
367
+ if process_button:
368
+ st.session_state.full_transcription = None
369
+ st.session_state.audio_data = None
370
+ st.session_state.info = None
371
+ st.session_state.video_data = None
372
+ st.session_state.video_filename = None
373
+ st.session_state.reset_audio_format = True
374
+
375
  if not video_url:
376
  st.error("Please enter a YouTube video link.")
377
  return
378
+
379
+ if transcribe_option:
380
+ st.session_state.full_transcription, st.session_state.audio_data, st.session_state.audio_format, st.session_state.info = process_transcription(video_url, vad_sensitivity, batch_size, transcriber, vad_model, audio_format, language)
381
+ if st.session_state.full_transcription:
382
+ transcription_output.text_area("Transcription:", value=st.session_state.full_transcription, height=300, key=random.random())
383
+
384
+ if download_audio_option:
385
+ if st.session_state.audio_data is None or st.session_state.audio_format is None or st.session_state.info is None:
386
+ st.session_state.audio_data, st.session_state.audio_format, st.session_state.info = download_and_convert_audio(video_url, audio_format)
387
+
388
+ if download_video_option:
389
+ st.session_state.video_data, st.session_state.video_filename, st.session_state.info = download_video(video_url, video_format)
390
+
391
+ # Download button logic (moved after setup_ui() call)
392
+ col1, col2, col3 = st.columns(3)
393
+ with col1:
394
+ if st.session_state.full_transcription and transcribe_option:
395
+ st.download_button(
396
+ label="Download Transcription (TXT)",
397
+ data=st.session_state.full_transcription,
398
+ file_name=f"{st.session_state.info['id']}_transcription.txt",
399
+ mime="text/plain"
400
+ )
401
+
402
+ with col2:
403
+ # Now download_audio_option is defined
404
+ if st.session_state.audio_data is not None and download_audio_option:
405
+ st.download_button(
406
+ label=f"Download Audio ({st.session_state.audio_format})",
407
+ data=st.session_state.audio_data,
408
+ file_name=f"{st.session_state.info['id']}.{st.session_state.audio_format}",
409
+ mime=f"audio/{st.session_state.audio_format}"
410
+ )
411
+
412
+ with col3:
413
+ if st.session_state.video_data is not None and download_video_option:
414
+ st.download_button(
415
+ label="Download Video",
416
+ data=st.session_state.video_data,
417
+ file_name=st.session_state.video_filename,
418
+ mime=f"video/{video_format}"
419
+ )
420
 
421
  if __name__ == "__main__":
422
  main()