Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -6,11 +6,12 @@ import numpy as np
|
|
6 |
import plotly.graph_objects as go
|
7 |
import tempfile
|
8 |
import os
|
|
|
9 |
|
10 |
# Set page config
|
11 |
st.set_page_config(page_title="π΅ Music Genre Classifier", layout="wide")
|
12 |
|
13 |
-
# Custom CSS
|
14 |
st.markdown("""
|
15 |
<style>
|
16 |
.main-title {
|
@@ -51,28 +52,35 @@ def load_model():
|
|
51 |
|
52 |
pipe = load_model()
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
def classify_audio(audio_file):
|
|
|
55 |
start_time = time.time()
|
56 |
-
|
57 |
-
#
|
58 |
-
|
59 |
-
tmp_file.write(audio_file.getvalue())
|
60 |
-
tmp_file_path = tmp_file.name
|
61 |
|
62 |
try:
|
63 |
-
|
64 |
-
preds = pipe(
|
65 |
outputs = {p["label"]: p["score"] for p in preds}
|
66 |
end_time = time.time()
|
67 |
prediction_time = end_time - start_time
|
68 |
-
return outputs, prediction_time
|
69 |
finally:
|
70 |
-
#
|
71 |
-
os.unlink(tmp_file_path)
|
72 |
|
|
|
73 |
st.markdown("<h1 class='main-title'>π΅ Music Genre Classifier</h1>", unsafe_allow_html=True)
|
74 |
st.markdown("<p class='sub-title'>Upload a music file and let AI detect its genre!</p>", unsafe_allow_html=True)
|
75 |
|
|
|
76 |
st.sidebar.title("About")
|
77 |
st.sidebar.info("""
|
78 |
This app uses a fine-tuned wav2vec2-base model to classify music genres.
|
@@ -80,23 +88,27 @@ Model: juangtzi/wav2vec2-base-finetuned-gtzan
|
|
80 |
Dataset: GTZAN
|
81 |
""")
|
82 |
|
|
|
83 |
uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"])
|
84 |
|
85 |
if uploaded_file is not None:
|
|
|
86 |
st.audio(uploaded_file)
|
87 |
|
|
|
88 |
if st.button("Classify Genre"):
|
89 |
with st.spinner("Analyzing the music... π§"):
|
90 |
try:
|
91 |
-
results, pred_time
|
92 |
|
93 |
-
# Get top genre
|
94 |
top_genre = max(results, key=results.get)
|
95 |
|
|
|
96 |
st.markdown(f"<h2 class='genre-result'>Detected Genre: {top_genre.capitalize()}</h2>", unsafe_allow_html=True)
|
97 |
st.markdown(f"<p class='prediction-time'>Prediction Time: {pred_time:.2f} seconds</p>", unsafe_allow_html=True)
|
98 |
|
99 |
-
#
|
100 |
fig = go.Figure(data=[go.Bar(
|
101 |
x=list(results.keys()),
|
102 |
y=list(results.values()),
|
@@ -111,7 +123,10 @@ if uploaded_file is not None:
|
|
111 |
)
|
112 |
st.plotly_chart(fig, use_container_width=True)
|
113 |
|
114 |
-
#
|
|
|
|
|
|
|
115 |
st.subheader("Audio Waveform")
|
116 |
fig_waveform = go.Figure(data=[go.Scatter(y=y, mode='lines', line=dict(color='#1DB954'))])
|
117 |
fig_waveform.update_layout(
|
@@ -127,8 +142,10 @@ if uploaded_file is not None:
|
|
127 |
st.error(f"An error occurred while processing the audio: {str(e)}")
|
128 |
st.info("Please try uploading the file again or use a different audio file.")
|
129 |
|
|
|
130 |
st.markdown("""
|
131 |
<div style='text-align: center; margin-top: 2rem;'>
|
132 |
<p>Created with β€οΈ by AI. Powered by Streamlit and Hugging Face Transformers.</p>
|
133 |
</div>
|
134 |
-
""", unsafe_allow_html=True)
|
|
|
|
6 |
import plotly.graph_objects as go
|
7 |
import tempfile
|
8 |
import os
|
9 |
+
import soundfile as sf
|
10 |
|
11 |
# Set page config
|
12 |
st.set_page_config(page_title="π΅ Music Genre Classifier", layout="wide")
|
13 |
|
14 |
+
# Custom CSS for UI
|
15 |
st.markdown("""
|
16 |
<style>
|
17 |
.main-title {
|
|
|
52 |
|
53 |
pipe = load_model()
|
54 |
|
55 |
+
def convert_to_wav(audio_file):
|
56 |
+
"""Converts uploaded audio file to WAV format."""
|
57 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_wav:
|
58 |
+
y, sr = librosa.load(audio_file, sr=None)
|
59 |
+
sf.write(tmp_wav.name, y, sr)
|
60 |
+
return tmp_wav.name
|
61 |
+
|
62 |
def classify_audio(audio_file):
|
63 |
+
"""Classifies the audio file using the loaded model."""
|
64 |
start_time = time.time()
|
65 |
+
|
66 |
+
# Convert to WAV format before passing to the model
|
67 |
+
wav_file = convert_to_wav(audio_file)
|
|
|
|
|
68 |
|
69 |
try:
|
70 |
+
# Use the wav file with the model
|
71 |
+
preds = pipe(wav_file)
|
72 |
outputs = {p["label"]: p["score"] for p in preds}
|
73 |
end_time = time.time()
|
74 |
prediction_time = end_time - start_time
|
75 |
+
return outputs, prediction_time
|
76 |
finally:
|
77 |
+
os.unlink(wav_file) # Remove the temp file
|
|
|
78 |
|
79 |
+
# Page title and subtitle
|
80 |
st.markdown("<h1 class='main-title'>π΅ Music Genre Classifier</h1>", unsafe_allow_html=True)
|
81 |
st.markdown("<p class='sub-title'>Upload a music file and let AI detect its genre!</p>", unsafe_allow_html=True)
|
82 |
|
83 |
+
# Sidebar with model and dataset information
|
84 |
st.sidebar.title("About")
|
85 |
st.sidebar.info("""
|
86 |
This app uses a fine-tuned wav2vec2-base model to classify music genres.
|
|
|
88 |
Dataset: GTZAN
|
89 |
""")
|
90 |
|
91 |
+
# Upload file section
|
92 |
uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"])
|
93 |
|
94 |
if uploaded_file is not None:
|
95 |
+
# Display the uploaded audio file
|
96 |
st.audio(uploaded_file)
|
97 |
|
98 |
+
# Classify the uploaded audio
|
99 |
if st.button("Classify Genre"):
|
100 |
with st.spinner("Analyzing the music... π§"):
|
101 |
try:
|
102 |
+
results, pred_time = classify_audio(uploaded_file)
|
103 |
|
104 |
+
# Get the top predicted genre
|
105 |
top_genre = max(results, key=results.get)
|
106 |
|
107 |
+
# Display the top predicted genre
|
108 |
st.markdown(f"<h2 class='genre-result'>Detected Genre: {top_genre.capitalize()}</h2>", unsafe_allow_html=True)
|
109 |
st.markdown(f"<p class='prediction-time'>Prediction Time: {pred_time:.2f} seconds</p>", unsafe_allow_html=True)
|
110 |
|
111 |
+
# Plot the genre probabilities as a bar chart
|
112 |
fig = go.Figure(data=[go.Bar(
|
113 |
x=list(results.keys()),
|
114 |
y=list(results.values()),
|
|
|
123 |
)
|
124 |
st.plotly_chart(fig, use_container_width=True)
|
125 |
|
126 |
+
# Load the audio for displaying waveform
|
127 |
+
y, sr = librosa.load(uploaded_file, sr=None)
|
128 |
+
|
129 |
+
# Plot the audio waveform
|
130 |
st.subheader("Audio Waveform")
|
131 |
fig_waveform = go.Figure(data=[go.Scatter(y=y, mode='lines', line=dict(color='#1DB954'))])
|
132 |
fig_waveform.update_layout(
|
|
|
142 |
st.error(f"An error occurred while processing the audio: {str(e)}")
|
143 |
st.info("Please try uploading the file again or use a different audio file.")
|
144 |
|
145 |
+
# Footer
|
146 |
st.markdown("""
|
147 |
<div style='text-align: center; margin-top: 2rem;'>
|
148 |
<p>Created with β€οΈ by AI. Powered by Streamlit and Hugging Face Transformers.</p>
|
149 |
</div>
|
150 |
+
""", unsafe_allow_html=True)
|
151 |
+
|