Spaces:
Running
Running
Commit
·
1e7df51
1
Parent(s):
6aa61fa
Script fixes
Browse files
app.py
CHANGED
@@ -18,6 +18,7 @@ from pytube import YouTube
|
|
18 |
from sklearn.preprocessing import StandardScaler
|
19 |
import shutil
|
20 |
import streamlit as st
|
|
|
21 |
|
22 |
|
23 |
# Constants
|
@@ -29,26 +30,27 @@ N_FEATURES = 15
|
|
29 |
MODEL_PATH = "models/CRNN/best_model_V3.h5"
|
30 |
AUDIO_TEMP_PATH = "output/temp"
|
31 |
|
32 |
-
|
|
|
33 |
try:
|
34 |
yt = YouTube(url)
|
35 |
video_title = yt.title
|
36 |
audio_stream = yt.streams.filter(only_audio=True).first()
|
37 |
if audio_stream:
|
38 |
-
|
39 |
-
out_file = audio_stream.download(
|
40 |
base, _ = os.path.splitext(out_file)
|
41 |
audio_file = base + '.mp3'
|
42 |
if os.path.exists(audio_file):
|
43 |
os.remove(audio_file)
|
44 |
os.rename(out_file, audio_file)
|
45 |
-
return audio_file, video_title
|
46 |
else:
|
47 |
st.error("No audio stream found")
|
48 |
-
return None, None
|
49 |
except Exception as e:
|
50 |
st.error(f"An error occurred: {e}")
|
51 |
-
return None, None
|
52 |
|
53 |
|
54 |
def strip_silence(audio_path):
|
@@ -412,38 +414,24 @@ def make_predictions(model, processed_audio, audio_features, url, video_name):
|
|
412 |
audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)
|
413 |
chorus_start_times = [meter_grid_times[i] for i in range(len(
|
414 |
smoothed_predictions)) if smoothed_predictions[i] == 1 and (i == 0 or smoothed_predictions[i - 1] == 0)]
|
|
|
|
|
415 |
|
416 |
-
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
f"
|
421 |
-
|
422 |
-
print()
|
423 |
-
print()
|
424 |
-
print(header_footer)
|
425 |
-
print(f"{video_name.center(max_length + 2)}")
|
426 |
-
print(f"Number of choruses identified: {len(chorus_start_times)}".center(
|
427 |
-
max_length + 4))
|
428 |
-
print(header_footer)
|
429 |
-
for link in youtube_links:
|
430 |
-
print(link)
|
431 |
-
print(header_footer)
|
432 |
|
433 |
if len(chorus_start_times) == 0:
|
434 |
-
|
435 |
|
436 |
return smoothed_predictions
|
437 |
|
438 |
|
439 |
def plot_meter_lines(ax: plt.Axes, meter_grid_times: np.ndarray) -> None:
|
440 |
-
"""
|
441 |
-
Draw meter grid lines on the plot.
|
442 |
-
|
443 |
-
Parameters:
|
444 |
-
- ax (plt.Axes): The matplotlib axes object to draw on.
|
445 |
-
- meter_grid_times (np.ndarray): Array of times at which to draw the meter lines.
|
446 |
-
"""
|
447 |
for time in meter_grid_times:
|
448 |
ax.axvline(x=time, color='grey', linestyle='--',
|
449 |
linewidth=1, alpha=0.6)
|
@@ -499,17 +487,16 @@ def main():
|
|
499 |
url = st.text_input("YouTube URL")
|
500 |
if st.button("Find Chorus"):
|
501 |
if url:
|
502 |
-
audio_file, video_title = extract_audio(url)
|
503 |
if audio_file:
|
504 |
strip_silence(audio_file)
|
505 |
-
processed_audio, audio_features = process_audio(audio_path=
|
506 |
model = load_model()
|
507 |
smoothed_predictions = make_predictions(model, processed_audio, audio_features, url, video_title)
|
508 |
plot_predictions(audio_features=audio_features, predictions=smoothed_predictions)
|
509 |
-
shutil.rmtree(
|
510 |
else:
|
511 |
st.error("Please enter a valid YouTube URL")
|
512 |
|
513 |
if __name__ == "__main__":
|
514 |
main()
|
515 |
-
|
|
|
18 |
from sklearn.preprocessing import StandardScaler
|
19 |
import shutil
|
20 |
import streamlit as st
|
21 |
+
import tempfile
|
22 |
|
23 |
|
24 |
# Constants
|
|
|
30 |
MODEL_PATH = "models/CRNN/best_model_V3.h5"
|
31 |
AUDIO_TEMP_PATH = "output/temp"
|
32 |
|
33 |
+
|
34 |
+
def extract_audio(url):
|
35 |
try:
|
36 |
yt = YouTube(url)
|
37 |
video_title = yt.title
|
38 |
audio_stream = yt.streams.filter(only_audio=True).first()
|
39 |
if audio_stream:
|
40 |
+
temp_dir = tempfile.mkdtemp()
|
41 |
+
out_file = audio_stream.download(temp_dir)
|
42 |
base, _ = os.path.splitext(out_file)
|
43 |
audio_file = base + '.mp3'
|
44 |
if os.path.exists(audio_file):
|
45 |
os.remove(audio_file)
|
46 |
os.rename(out_file, audio_file)
|
47 |
+
return audio_file, video_title, temp_dir
|
48 |
else:
|
49 |
st.error("No audio stream found")
|
50 |
+
return None, None, None
|
51 |
except Exception as e:
|
52 |
st.error(f"An error occurred: {e}")
|
53 |
+
return None, None, None
|
54 |
|
55 |
|
56 |
def strip_silence(audio_path):
|
|
|
414 |
audio_features.meter_grid, sr=audio_features.sr, hop_length=audio_features.hop_length)
|
415 |
chorus_start_times = [meter_grid_times[i] for i in range(len(
|
416 |
smoothed_predictions)) if smoothed_predictions[i] == 1 and (i == 0 or smoothed_predictions[i - 1] == 0)]
|
417 |
+
chorus_end_times = [meter_grid_times[i + 1] for i in range(len(
|
418 |
+
smoothed_predictions)) if smoothed_predictions[i] == 1 and (i == len(smoothed_predictions) - 1 or smoothed_predictions[i + 1] == 0)]
|
419 |
|
420 |
+
st.write(f"**Video Title:** {video_name}")
|
421 |
+
st.write(f"**Number of choruses identified:** {len(chorus_start_times)}")
|
422 |
+
|
423 |
+
for start_time, end_time in zip(chorus_start_times, chorus_end_times):
|
424 |
+
link = f"{url}&t={int(start_time)}s"
|
425 |
+
st.write(f"Chorus from {start_time:.2f}s to {end_time:.2f}s: [Link]({link})")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
|
427 |
if len(chorus_start_times) == 0:
|
428 |
+
st.write("No choruses identified.")
|
429 |
|
430 |
return smoothed_predictions
|
431 |
|
432 |
|
433 |
def plot_meter_lines(ax: plt.Axes, meter_grid_times: np.ndarray) -> None:
|
434 |
+
"""Draw meter grid lines on the plot."""
|
|
|
|
|
|
|
|
|
|
|
|
|
435 |
for time in meter_grid_times:
|
436 |
ax.axvline(x=time, color='grey', linestyle='--',
|
437 |
linewidth=1, alpha=0.6)
|
|
|
487 |
url = st.text_input("YouTube URL")
|
488 |
if st.button("Find Chorus"):
|
489 |
if url:
|
490 |
+
audio_file, video_title, temp_dir = extract_audio(url)
|
491 |
if audio_file:
|
492 |
strip_silence(audio_file)
|
493 |
+
processed_audio, audio_features = process_audio(audio_path=audio_file)
|
494 |
model = load_model()
|
495 |
smoothed_predictions = make_predictions(model, processed_audio, audio_features, url, video_title)
|
496 |
plot_predictions(audio_features=audio_features, predictions=smoothed_predictions)
|
497 |
+
shutil.rmtree(temp_dir)
|
498 |
else:
|
499 |
st.error("Please enter a valid YouTube URL")
|
500 |
|
501 |
if __name__ == "__main__":
|
502 |
main()
|
|