dennisvdang commited on
Commit
1e7df51
·
1 Parent(s): 6aa61fa

Script fixes

Browse files
Files changed (1) hide show
  1. app.py +21 -34
app.py CHANGED
@@ -18,6 +18,7 @@ from pytube import YouTube
18
  from sklearn.preprocessing import StandardScaler
19
  import shutil
20
  import streamlit as st
 
21
 
22
 
23
  # Constants
@@ -29,26 +30,27 @@ N_FEATURES = 15
29
  MODEL_PATH = "models/CRNN/best_model_V3.h5"
30
  AUDIO_TEMP_PATH = "output/temp"
31
 
32
- def extract_audio(url, output_path=AUDIO_TEMP_PATH):
 
33
  try:
34
  yt = YouTube(url)
35
  video_title = yt.title
36
  audio_stream = yt.streams.filter(only_audio=True).first()
37
  if audio_stream:
38
- os.makedirs(output_path, exist_ok=True)
39
- out_file = audio_stream.download(output_path)
40
  base, _ = os.path.splitext(out_file)
41
  audio_file = base + '.mp3'
42
  if os.path.exists(audio_file):
43
  os.remove(audio_file)
44
  os.rename(out_file, audio_file)
45
- return audio_file, video_title
46
  else:
47
  st.error("No audio stream found")
48
- return None, None
49
  except Exception as e:
50
  st.error(f"An error occurred: {e}")
51
- return None, None
52
 
53
 
54
  def strip_silence(audio_path):
@@ -412,38 +414,24 @@ def make_predictions(model, processed_audio, audio_features, url, video_name):
412
  audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)
413
  chorus_start_times = [meter_grid_times[i] for i in range(len(
414
  smoothed_predictions)) if smoothed_predictions[i] == 1 and (i == 0 or smoothed_predictions[i - 1] == 0)]
 
 
415
 
416
- youtube_links = [
417
- f"\033]8;;{url}&t={int(start_time)}s\033\\{url}&t={int(start_time)}s\033]8;;\033\\" for start_time in chorus_start_times
418
- ]
419
- max_length = max([len(link) for link in youtube_links] + [len(video_name), len(
420
- f"Number of choruses identified: {len(chorus_start_times)}")] if chorus_start_times else [0])
421
- header_footer = "=" * (max_length + 4)
422
- print()
423
- print()
424
- print(header_footer)
425
- print(f"{video_name.center(max_length + 2)}")
426
- print(f"Number of choruses identified: {len(chorus_start_times)}".center(
427
- max_length + 4))
428
- print(header_footer)
429
- for link in youtube_links:
430
- print(link)
431
- print(header_footer)
432
 
433
  if len(chorus_start_times) == 0:
434
- print("No choruses identified.")
435
 
436
  return smoothed_predictions
437
 
438
 
439
  def plot_meter_lines(ax: plt.Axes, meter_grid_times: np.ndarray) -> None:
440
- """
441
- Draw meter grid lines on the plot.
442
-
443
- Parameters:
444
- - ax (plt.Axes): The matplotlib axes object to draw on.
445
- - meter_grid_times (np.ndarray): Array of times at which to draw the meter lines.
446
- """
447
  for time in meter_grid_times:
448
  ax.axvline(x=time, color='grey', linestyle='--',
449
  linewidth=1, alpha=0.6)
@@ -499,17 +487,16 @@ def main():
499
  url = st.text_input("YouTube URL")
500
  if st.button("Find Chorus"):
501
  if url:
502
- audio_file, video_title = extract_audio(url)
503
  if audio_file:
504
  strip_silence(audio_file)
505
- processed_audio, audio_features = process_audio(audio_path=AUDIO_TEMP_PATH)
506
  model = load_model()
507
  smoothed_predictions = make_predictions(model, processed_audio, audio_features, url, video_title)
508
  plot_predictions(audio_features=audio_features, predictions=smoothed_predictions)
509
- shutil.rmtree(AUDIO_TEMP_PATH)
510
  else:
511
  st.error("Please enter a valid YouTube URL")
512
 
513
  if __name__ == "__main__":
514
  main()
515
-
 
18
  from sklearn.preprocessing import StandardScaler
19
  import shutil
20
  import streamlit as st
21
+ import tempfile
22
 
23
 
24
  # Constants
 
30
  MODEL_PATH = "models/CRNN/best_model_V3.h5"
31
  AUDIO_TEMP_PATH = "output/temp"
32
 
33
+
34
+ def extract_audio(url):
35
  try:
36
  yt = YouTube(url)
37
  video_title = yt.title
38
  audio_stream = yt.streams.filter(only_audio=True).first()
39
  if audio_stream:
40
+ temp_dir = tempfile.mkdtemp()
41
+ out_file = audio_stream.download(temp_dir)
42
  base, _ = os.path.splitext(out_file)
43
  audio_file = base + '.mp3'
44
  if os.path.exists(audio_file):
45
  os.remove(audio_file)
46
  os.rename(out_file, audio_file)
47
+ return audio_file, video_title, temp_dir
48
  else:
49
  st.error("No audio stream found")
50
+ return None, None, None
51
  except Exception as e:
52
  st.error(f"An error occurred: {e}")
53
+ return None, None, None
54
 
55
 
56
  def strip_silence(audio_path):
 
414
  audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)
415
  chorus_start_times = [meter_grid_times[i] for i in range(len(
416
  smoothed_predictions)) if smoothed_predictions[i] == 1 and (i == 0 or smoothed_predictions[i - 1] == 0)]
417
+ chorus_end_times = [meter_grid_times[i + 1] for i in range(len(
418
+ smoothed_predictions)) if smoothed_predictions[i] == 1 and (i == len(smoothed_predictions) - 1 or smoothed_predictions[i + 1] == 0)]
419
 
420
+ st.write(f"**Video Title:** {video_name}")
421
+ st.write(f"**Number of choruses identified:** {len(chorus_start_times)}")
422
+
423
+ for start_time, end_time in zip(chorus_start_times, chorus_end_times):
424
+ link = f"{url}&t={int(start_time)}s"
425
+ st.write(f"Chorus from {start_time:.2f}s to {end_time:.2f}s: [Link]({link})")
 
 
 
 
 
 
 
 
 
 
426
 
427
  if len(chorus_start_times) == 0:
428
+ st.write("No choruses identified.")
429
 
430
  return smoothed_predictions
431
 
432
 
433
  def plot_meter_lines(ax: plt.Axes, meter_grid_times: np.ndarray) -> None:
434
+ """Draw meter grid lines on the plot."""
 
 
 
 
 
 
435
  for time in meter_grid_times:
436
  ax.axvline(x=time, color='grey', linestyle='--',
437
  linewidth=1, alpha=0.6)
 
487
  url = st.text_input("YouTube URL")
488
  if st.button("Find Chorus"):
489
  if url:
490
+ audio_file, video_title, temp_dir = extract_audio(url)
491
  if audio_file:
492
  strip_silence(audio_file)
493
+ processed_audio, audio_features = process_audio(audio_path=audio_file)
494
  model = load_model()
495
  smoothed_predictions = make_predictions(model, processed_audio, audio_features, url, video_title)
496
  plot_predictions(audio_features=audio_features, predictions=smoothed_predictions)
497
+ shutil.rmtree(temp_dir)
498
  else:
499
  st.error("Please enter a valid YouTube URL")
500
 
501
  if __name__ == "__main__":
502
  main()