ibrahim313 commited on
Commit
dd79c10
Β·
verified Β·
1 Parent(s): 85d10dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -41
app.py CHANGED
@@ -4,11 +4,13 @@ from transformers import pipeline
4
  import librosa
5
  import numpy as np
6
  import plotly.graph_objects as go
 
 
7
 
8
  # Set page config
9
  st.set_page_config(page_title="🎡 Music Genre Classifier", layout="wide")
10
 
11
- # Custom CSS
12
  st.markdown("""
13
  <style>
14
  .main-title {
@@ -51,12 +53,22 @@ pipe = load_model()
51
 
52
  def classify_audio(audio_file):
53
  start_time = time.time()
54
- y, sr = librosa.load(audio_file, sr=None)
55
- preds = pipe(y)
56
- outputs = {p["label"]: p["score"] for p in preds}
57
- end_time = time.time()
58
- prediction_time = end_time - start_time
59
- return outputs, prediction_time
 
 
 
 
 
 
 
 
 
 
60
 
61
  st.markdown("<h1 class='main-title'>🎡 Music Genre Classifier</h1>", unsafe_allow_html=True)
62
  st.markdown("<p class='sub-title'>Upload a music file and let AI detect its genre!</p>", unsafe_allow_html=True)
@@ -75,41 +87,45 @@ if uploaded_file is not None:
75
 
76
  if st.button("Classify Genre"):
77
  with st.spinner("Analyzing the music... 🎧"):
78
- results, pred_time = classify_audio(uploaded_file)
79
-
80
- # Get top genre
81
- top_genre = max(results, key=results.get)
82
-
83
- st.markdown(f"<h2 class='genre-result'>Detected Genre: {top_genre.capitalize()}</h2>", unsafe_allow_html=True)
84
- st.markdown(f"<p class='prediction-time'>Prediction Time: {pred_time:.2f} seconds</p>", unsafe_allow_html=True)
85
-
86
- # Create a bar chart using Plotly
87
- fig = go.Figure(data=[go.Bar(
88
- x=list(results.keys()),
89
- y=list(results.values()),
90
- marker_color='#1DB954'
91
- )])
92
- fig.update_layout(
93
- title="Genre Probabilities",
94
- xaxis_title="Genre",
95
- yaxis_title="Probability",
96
- paper_bgcolor='rgba(0,0,0,0)',
97
- plot_bgcolor='rgba(0,0,0,0)'
98
- )
99
- st.plotly_chart(fig, use_container_width=True)
 
100
 
101
- # Display waveform
102
- st.subheader("Audio Waveform")
103
- y, sr = librosa.load(uploaded_file, sr=None)
104
- fig_waveform = go.Figure(data=[go.Scatter(y=y, mode='lines', line=dict(color='#1DB954'))])
105
- fig_waveform.update_layout(
106
- title="Audio Waveform",
107
- xaxis_title="Time",
108
- yaxis_title="Amplitude",
109
- paper_bgcolor='rgba(0,0,0,0)',
110
- plot_bgcolor='rgba(0,0,0,0)'
111
- )
112
- st.plotly_chart(fig_waveform, use_container_width=True)
 
 
 
113
 
114
  st.markdown("""
115
  <div style='text-align: center; margin-top: 2rem;'>
 
4
  import librosa
5
  import numpy as np
6
  import plotly.graph_objects as go
7
+ import tempfile
8
+ import os
9
 
10
  # Set page config
11
  st.set_page_config(page_title="🎡 Music Genre Classifier", layout="wide")
12
 
13
+ # Custom CSS (unchanged)
14
  st.markdown("""
15
  <style>
16
  .main-title {
 
53
 
54
  def classify_audio(audio_file):
55
  start_time = time.time()
56
+
57
+ # Save the uploaded file to a temporary file
58
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
59
+ tmp_file.write(audio_file.getvalue())
60
+ tmp_file_path = tmp_file.name
61
+
62
+ try:
63
+ y, sr = librosa.load(tmp_file_path, sr=None)
64
+ preds = pipe(y)
65
+ outputs = {p["label"]: p["score"] for p in preds}
66
+ end_time = time.time()
67
+ prediction_time = end_time - start_time
68
+ return outputs, prediction_time, y, sr
69
+ finally:
70
+ # Make sure to remove the temporary file
71
+ os.unlink(tmp_file_path)
72
 
73
  st.markdown("<h1 class='main-title'>🎡 Music Genre Classifier</h1>", unsafe_allow_html=True)
74
  st.markdown("<p class='sub-title'>Upload a music file and let AI detect its genre!</p>", unsafe_allow_html=True)
 
87
 
88
  if st.button("Classify Genre"):
89
  with st.spinner("Analyzing the music... 🎧"):
90
+ try:
91
+ results, pred_time, y, sr = classify_audio(uploaded_file)
92
+
93
+ # Get top genre
94
+ top_genre = max(results, key=results.get)
95
+
96
+ st.markdown(f"<h2 class='genre-result'>Detected Genre: {top_genre.capitalize()}</h2>", unsafe_allow_html=True)
97
+ st.markdown(f"<p class='prediction-time'>Prediction Time: {pred_time:.2f} seconds</p>", unsafe_allow_html=True)
98
+
99
+ # Create a bar chart using Plotly
100
+ fig = go.Figure(data=[go.Bar(
101
+ x=list(results.keys()),
102
+ y=list(results.values()),
103
+ marker_color='#1DB954'
104
+ )])
105
+ fig.update_layout(
106
+ title="Genre Probabilities",
107
+ xaxis_title="Genre",
108
+ yaxis_title="Probability",
109
+ paper_bgcolor='rgba(0,0,0,0)',
110
+ plot_bgcolor='rgba(0,0,0,0)'
111
+ )
112
+ st.plotly_chart(fig, use_container_width=True)
113
 
114
+ # Display waveform
115
+ st.subheader("Audio Waveform")
116
+ fig_waveform = go.Figure(data=[go.Scatter(y=y, mode='lines', line=dict(color='#1DB954'))])
117
+ fig_waveform.update_layout(
118
+ title="Audio Waveform",
119
+ xaxis_title="Time",
120
+ yaxis_title="Amplitude",
121
+ paper_bgcolor='rgba(0,0,0,0)',
122
+ plot_bgcolor='rgba(0,0,0,0)'
123
+ )
124
+ st.plotly_chart(fig_waveform, use_container_width=True)
125
+
126
+ except Exception as e:
127
+ st.error(f"An error occurred while processing the audio: {str(e)}")
128
+ st.info("Please try uploading the file again or use a different audio file.")
129
 
130
  st.markdown("""
131
  <div style='text-align: center; margin-top: 2rem;'>