ibrahim313 committed on
Commit
9217d22
·
verified ·
1 Parent(s): 792958c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -171
app.py CHANGED
@@ -1,196 +1,118 @@
1
  import streamlit as st
2
- import os
3
- import tempfile
4
- from transformers import pipeline, AutoProcessor, AutoModelForAudioClassification
5
  import librosa
6
- import soundfile as sf
7
- from streamlit.components.v1 import html
8
- import base64
9
 
10
- # Set page configuration
11
- st.set_page_config(page_title="Music Genre Classifier", layout="wide")
12
 
13
- # Custom CSS for styling (supports both light and dark modes)
14
- custom_css = """
15
  <style>
16
- .stApp {
17
- transition: background-color 0.3s ease;
18
- }
19
  .main-title {
20
- font-size: 3em;
21
- font-weight: bold;
22
  text-align: center;
23
- margin-bottom: 30px;
 
24
  }
25
  .sub-title {
26
- font-size: 1.5em;
 
27
  text-align: center;
28
- margin-bottom: 20px;
29
  }
30
- .result-container {
31
- border-radius: 10px;
32
- padding: 20px;
33
- margin-top: 20px;
34
  }
35
  .genre-result {
36
- font-size: 2em;
37
  font-weight: bold;
38
  text-align: center;
 
 
39
  }
40
- .confidence-bar {
41
- height: 30px;
42
- border-radius: 15px;
43
- transition: background-color 0.3s ease;
44
- }
45
- /* Light mode styles */
46
- .light-mode .stApp {
47
- background-color: #f0f0f5;
48
- }
49
- .light-mode .main-title, .light-mode .sub-title {
50
- color: #1e1e1e;
51
- }
52
- .light-mode .result-container {
53
- background-color: #ffffff;
54
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
55
- }
56
- .light-mode .genre-result {
57
- color: #2c3e50;
58
- }
59
- .light-mode .confidence-bar {
60
- background-color: #3498db;
61
- }
62
- /* Dark mode styles */
63
- .dark-mode .stApp {
64
- background-color: #1e1e1e;
65
- }
66
- .dark-mode .main-title, .dark-mode .sub-title {
67
- color: #f0f0f5;
68
- }
69
- .dark-mode .result-container {
70
- background-color: #2c2c2c;
71
- box-shadow: 0 4px 6px rgba(255, 255, 255, 0.1);
72
- }
73
- .dark-mode .genre-result {
74
- color: #3498db;
75
- }
76
- .dark-mode .confidence-bar {
77
- background-color: #3498db;
78
  }
79
  </style>
80
- """
81
-
82
- # Render custom CSS
83
- st.markdown(custom_css, unsafe_allow_html=True)
84
 
85
- # Function to load the model
86
  @st.cache_resource
87
  def load_model():
88
- try:
89
- processor = AutoProcessor.from_pretrained("sandychoii/distilhubert-finetuned-gtzan-audio-classification")
90
- model = AutoModelForAudioClassification.from_pretrained("sandychoii/distilhubert-finetuned-gtzan-audio-classification")
91
- pipe = pipeline("audio-classification", model=model, feature_extractor=processor)
92
- return pipe
93
- except Exception as e:
94
- st.error(f"Error loading the model: {str(e)}")
95
- st.info("Please check your internet connection and try again. If the problem persists, the model might be temporarily unavailable.")
96
- return None
97
 
98
- # Load the model
99
  pipe = load_model()
100
 
101
- # Function to classify audio
102
  def classify_audio(audio_file):
103
- try:
104
- # Load audio file
105
- y, sr = librosa.load(audio_file, sr=None)
106
-
107
- # Ensure the audio is at least 3 seconds long (model requirement)
108
- if len(y) < 3 * sr:
109
- y = librosa.util.fix_length(y, size=3 * sr)
110
-
111
- # Classification
112
- result = pipe(y, sampling_rate=sr)
113
- return result
114
- except Exception as e:
115
- st.error(f"Error during classification: {str(e)}")
116
- return None
117
-
118
- # Function to toggle between light and dark mode
119
- def toggle_theme():
120
- if 'theme' not in st.session_state:
121
- st.session_state.theme = 'light'
122
- if st.session_state.theme == 'light':
123
- st.session_state.theme = 'dark'
124
- else:
125
- st.session_state.theme = 'light'
126
- st.experimental_rerun()
127
-
128
- # Main app
129
- def main():
130
- # Set theme class
131
- theme_class = 'light-mode' if st.session_state.get('theme', 'light') == 'light' else 'dark-mode'
132
- st.markdown(f'<div class="{theme_class}">', unsafe_allow_html=True)
133
-
134
- st.markdown("<h1 class='main-title'>🎵 Music Genre Classifier 🎸</h1>", unsafe_allow_html=True)
135
- st.markdown("<p class='sub-title'>Upload a music file and let AI detect its genre!</p>", unsafe_allow_html=True)
136
-
137
- # Theme toggle button
138
- if st.button("Toggle Light/Dark Mode"):
139
- toggle_theme()
140
-
141
- # File uploader
142
- uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"])
143
-
144
- if uploaded_file is not None:
145
- # Display audio player
146
- st.audio(uploaded_file)
147
-
148
- # Classify button
149
- if st.button("Classify Genre"):
150
- if pipe is None:
151
- st.error("Model is not loaded. Please check your internet connection and try again.")
152
- else:
153
- with st.spinner("Analyzing the music... 🎧"):
154
- # Save uploaded file temporarily
155
- with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
156
- tmp_file.write(uploaded_file.getvalue())
157
- tmp_file_path = tmp_file.name
158
-
159
- # Perform classification
160
- result = classify_audio(tmp_file_path)
161
-
162
- # Remove temporary file
163
- os.unlink(tmp_file_path)
164
-
165
- if result:
166
- # Display results
167
- st.markdown("<div class='result-container'>", unsafe_allow_html=True)
168
- st.markdown(f"<h2 class='genre-result'>Detected Genre: {result[0]['label'].capitalize()}</h2>", unsafe_allow_html=True)
169
-
170
- # Display confidence bar
171
- confidence = result[0]['score']
172
- st.markdown(f"<div class='confidence-bar' style='width: {confidence*100}%;'></div>", unsafe_allow_html=True)
173
- st.write(f"Confidence: {confidence:.2%}")
174
-
175
- # Display top 3 predictions
176
- st.write("Top 3 Predictions:")
177
- for r in result[:3]:
178
- st.write(f"- {r['label'].capitalize()}: {r['score']:.2%}")
179
- st.markdown("</div>", unsafe_allow_html=True)
180
-
181
- # Add information about the model
182
- st.sidebar.title("About")
183
- st.sidebar.info("This app uses a fine-tuned DistilHuBERT model to classify music genres. It can identify genres like rock, pop, hip-hop, classical, and more!")
184
-
185
- # Add a footer
186
- footer_html = """
187
- <div style="position: fixed; bottom: 0; width: 100%; text-align: center; padding: 10px;">
188
- <p>Created with ❤️ by AI. Powered by Streamlit and Hugging Face Transformers.</p>
189
- </div>
190
- """
191
- st.markdown(footer_html, unsafe_allow_html=True)
192
-
193
- st.markdown('</div>', unsafe_allow_html=True)
194
-
195
- if __name__ == "__main__":
196
- main()
 
1
  import streamlit as st
2
+ import time
3
+ from transformers import pipeline
 
4
  import librosa
5
+ import numpy as np
6
+ import plotly.graph_objects as go
 
7
 
8
+ # Set page config
9
+ st.set_page_config(page_title="🎵 Music Genre Classifier", layout="wide")
10
 
11
+ # Custom CSS
12
+ st.markdown("""
13
  <style>
 
 
 
14
  .main-title {
15
+ font-size: 3rem;
16
+ color: #1DB954;
17
  text-align: center;
18
+ padding: 2rem 0;
19
+ text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
20
  }
21
  .sub-title {
22
+ font-size: 1.5rem;
23
+ color: #191414;
24
  text-align: center;
25
+ margin-bottom: 2rem;
26
  }
27
+ .stAudio {
28
+ margin: 2rem auto;
29
+ display: block;
 
30
  }
31
  .genre-result {
32
+ font-size: 2rem;
33
  font-weight: bold;
34
  text-align: center;
35
+ color: #1DB954;
36
+ margin: 1rem 0;
37
  }
38
+ .prediction-time {
39
+ font-size: 1.2rem;
40
+ color: #191414;
41
+ text-align: center;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  }
43
  </style>
44
+ """, unsafe_allow_html=True)
 
 
 
45
 
 
46
  @st.cache_resource
47
  def load_model():
48
+ return pipeline("audio-classification", model="juangtzi/wav2vec2-base-finetuned-gtzan")
 
 
 
 
 
 
 
 
49
 
 
50
  pipe = load_model()
51
 
 
52
  def classify_audio(audio_file):
53
+ start_time = time.time()
54
+ y, sr = librosa.load(audio_file, sr=None)
55
+ preds = pipe(y)
56
+ outputs = {p["label"]: p["score"] for p in preds}
57
+ end_time = time.time()
58
+ prediction_time = end_time - start_time
59
+ return outputs, prediction_time
60
+
61
+ st.markdown("<h1 class='main-title'>🎵 Music Genre Classifier</h1>", unsafe_allow_html=True)
62
+ st.markdown("<p class='sub-title'>Upload a music file and let AI detect its genre!</p>", unsafe_allow_html=True)
63
+
64
+ st.sidebar.title("About")
65
+ st.sidebar.info("""
66
+ This app uses a fine-tuned wav2vec2-base model to classify music genres.
67
+ Model: juangtzi/wav2vec2-base-finetuned-gtzan
68
+ Dataset: GTZAN
69
+ """)
70
+
71
+ uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"])
72
+
73
if uploaded_file is not None:
    # Let the user listen to the clip before (or instead of) classifying it.
    st.audio(uploaded_file)

    if st.button("Classify Genre"):
        with st.spinner("Analyzing the music... 🎧"):
            results, pred_time = classify_audio(uploaded_file)

        # Get top genre: the label with the highest score.
        top_genre = max(results, key=results.get)

        st.markdown(f"<h2 class='genre-result'>Detected Genre: {top_genre.capitalize()}</h2>", unsafe_allow_html=True)
        st.markdown(f"<p class='prediction-time'>Prediction Time: {pred_time:.2f} seconds</p>", unsafe_allow_html=True)

        # Create a bar chart using Plotly
        fig = go.Figure(data=[go.Bar(
            x=list(results.keys()),
            y=list(results.values()),
            marker_color='#1DB954'
        )])
        fig.update_layout(
            title="Genre Probabilities",
            xaxis_title="Genre",
            yaxis_title="Probability",
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)'
        )
        st.plotly_chart(fig, use_container_width=True)

        # Display waveform
        st.subheader("Audio Waveform")

        # The upload is a stream that classify_audio has already read, so
        # rewind it before decoding it a second time.
        uploaded_file.seek(0)
        y, sr = librosa.load(uploaded_file, sr=None)

        # A full-rate signal can hold millions of samples, which is far more
        # than a line chart can usefully show and very heavy to send to the
        # browser. Keep every `step`-th sample (plain striding is adequate
        # for a visual overview) and plot against seconds so the "Time" axis
        # is meaningful.
        max_points = 20_000
        step = max(1, len(y) // max_points)
        y_plot = y[::step]
        t_seconds = np.arange(len(y_plot)) * (step / sr)

        fig_waveform = go.Figure(data=[go.Scatter(
            x=t_seconds,
            y=y_plot,
            mode='lines',
            line=dict(color='#1DB954')
        )])
        fig_waveform.update_layout(
            title="Audio Waveform",
            xaxis_title="Time",
            yaxis_title="Amplitude",
            paper_bgcolor='rgba(0,0,0,0)',
            plot_bgcolor='rgba(0,0,0,0)'
        )
        st.plotly_chart(fig_waveform, use_container_width=True)
113
+
114
+ st.markdown("""
115
+ <div style='text-align: center; margin-top: 2rem;'>
116
+ <p>Created with ❤️ by AI. Powered by Streamlit and Hugging Face Transformers.</p>
117
+ </div>
118
+ """, unsafe_allow_html=True)