ibrahim313 commited on
Commit
9dc92cb
Β·
verified Β·
1 Parent(s): dd79c10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -16
app.py CHANGED
@@ -6,11 +6,12 @@ import numpy as np
6
  import plotly.graph_objects as go
7
  import tempfile
8
  import os
 
9
 
10
  # Set page config
11
  st.set_page_config(page_title="🎡 Music Genre Classifier", layout="wide")
12
 
13
- # Custom CSS (unchanged)
14
  st.markdown("""
15
  <style>
16
  .main-title {
@@ -51,28 +52,35 @@ def load_model():
51
 
52
  pipe = load_model()
53
 
 
 
 
 
 
 
 
54
  def classify_audio(audio_file):
 
55
  start_time = time.time()
56
-
57
- # Save the uploaded file to a temporary file
58
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
59
- tmp_file.write(audio_file.getvalue())
60
- tmp_file_path = tmp_file.name
61
 
62
  try:
63
- y, sr = librosa.load(tmp_file_path, sr=None)
64
- preds = pipe(y)
65
  outputs = {p["label"]: p["score"] for p in preds}
66
  end_time = time.time()
67
  prediction_time = end_time - start_time
68
- return outputs, prediction_time, y, sr
69
  finally:
70
- # Make sure to remove the temporary file
71
- os.unlink(tmp_file_path)
72
 
 
73
  st.markdown("<h1 class='main-title'>🎡 Music Genre Classifier</h1>", unsafe_allow_html=True)
74
  st.markdown("<p class='sub-title'>Upload a music file and let AI detect its genre!</p>", unsafe_allow_html=True)
75
 
 
76
  st.sidebar.title("About")
77
  st.sidebar.info("""
78
  This app uses a fine-tuned wav2vec2-base model to classify music genres.
@@ -80,23 +88,27 @@ Model: juangtzi/wav2vec2-base-finetuned-gtzan
80
  Dataset: GTZAN
81
  """)
82
 
 
83
  uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"])
84
 
85
  if uploaded_file is not None:
 
86
  st.audio(uploaded_file)
87
 
 
88
  if st.button("Classify Genre"):
89
  with st.spinner("Analyzing the music... 🎧"):
90
  try:
91
- results, pred_time, y, sr = classify_audio(uploaded_file)
92
 
93
- # Get top genre
94
  top_genre = max(results, key=results.get)
95
 
 
96
  st.markdown(f"<h2 class='genre-result'>Detected Genre: {top_genre.capitalize()}</h2>", unsafe_allow_html=True)
97
  st.markdown(f"<p class='prediction-time'>Prediction Time: {pred_time:.2f} seconds</p>", unsafe_allow_html=True)
98
 
99
- # Create a bar chart using Plotly
100
  fig = go.Figure(data=[go.Bar(
101
  x=list(results.keys()),
102
  y=list(results.values()),
@@ -111,7 +123,10 @@ if uploaded_file is not None:
111
  )
112
  st.plotly_chart(fig, use_container_width=True)
113
 
114
- # Display waveform
 
 
 
115
  st.subheader("Audio Waveform")
116
  fig_waveform = go.Figure(data=[go.Scatter(y=y, mode='lines', line=dict(color='#1DB954'))])
117
  fig_waveform.update_layout(
@@ -127,8 +142,10 @@ if uploaded_file is not None:
127
  st.error(f"An error occurred while processing the audio: {str(e)}")
128
  st.info("Please try uploading the file again or use a different audio file.")
129
 
 
130
  st.markdown("""
131
  <div style='text-align: center; margin-top: 2rem;'>
132
  <p>Created with ❀️ by AI. Powered by Streamlit and Hugging Face Transformers.</p>
133
  </div>
134
- """, unsafe_allow_html=True)
 
 
6
  import plotly.graph_objects as go
7
  import tempfile
8
  import os
9
+ import soundfile as sf
10
 
11
  # Set page config
12
  st.set_page_config(page_title="🎡 Music Genre Classifier", layout="wide")
13
 
14
+ # Custom CSS for UI
15
  st.markdown("""
16
  <style>
17
  .main-title {
 
52
 
53
  pipe = load_model()
54
 
55
+ def convert_to_wav(audio_file):
56
+ """Converts uploaded audio file to WAV format."""
57
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_wav:
58
+ y, sr = librosa.load(audio_file, sr=None)
59
+ sf.write(tmp_wav.name, y, sr)
60
+ return tmp_wav.name
61
+
62
  def classify_audio(audio_file):
63
+ """Classifies the audio file using the loaded model."""
64
  start_time = time.time()
65
+
66
+ # Convert to WAV format before passing to the model
67
+ wav_file = convert_to_wav(audio_file)
 
 
68
 
69
  try:
70
+ # Use the wav file with the model
71
+ preds = pipe(wav_file)
72
  outputs = {p["label"]: p["score"] for p in preds}
73
  end_time = time.time()
74
  prediction_time = end_time - start_time
75
+ return outputs, prediction_time
76
  finally:
77
+ os.unlink(wav_file) # Remove the temp file
 
78
 
79
+ # Page title and subtitle
80
  st.markdown("<h1 class='main-title'>🎡 Music Genre Classifier</h1>", unsafe_allow_html=True)
81
  st.markdown("<p class='sub-title'>Upload a music file and let AI detect its genre!</p>", unsafe_allow_html=True)
82
 
83
+ # Sidebar with model and dataset information
84
  st.sidebar.title("About")
85
  st.sidebar.info("""
86
  This app uses a fine-tuned wav2vec2-base model to classify music genres.
 
88
  Dataset: GTZAN
89
  """)
90
 
91
+ # Upload file section
92
  uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"])
93
 
94
  if uploaded_file is not None:
95
+ # Display the uploaded audio file
96
  st.audio(uploaded_file)
97
 
98
+ # Classify the uploaded audio
99
  if st.button("Classify Genre"):
100
  with st.spinner("Analyzing the music... 🎧"):
101
  try:
102
+ results, pred_time = classify_audio(uploaded_file)
103
 
104
+ # Get the top predicted genre
105
  top_genre = max(results, key=results.get)
106
 
107
+ # Display the top predicted genre
108
  st.markdown(f"<h2 class='genre-result'>Detected Genre: {top_genre.capitalize()}</h2>", unsafe_allow_html=True)
109
  st.markdown(f"<p class='prediction-time'>Prediction Time: {pred_time:.2f} seconds</p>", unsafe_allow_html=True)
110
 
111
+ # Plot the genre probabilities as a bar chart
112
  fig = go.Figure(data=[go.Bar(
113
  x=list(results.keys()),
114
  y=list(results.values()),
 
123
  )
124
  st.plotly_chart(fig, use_container_width=True)
125
 
126
+ # Load the audio for displaying waveform
127
+ y, sr = librosa.load(uploaded_file, sr=None)
128
+
129
+ # Plot the audio waveform
130
  st.subheader("Audio Waveform")
131
  fig_waveform = go.Figure(data=[go.Scatter(y=y, mode='lines', line=dict(color='#1DB954'))])
132
  fig_waveform.update_layout(
 
142
  st.error(f"An error occurred while processing the audio: {str(e)}")
143
  st.info("Please try uploading the file again or use a different audio file.")
144
 
145
+ # Footer
146
  st.markdown("""
147
  <div style='text-align: center; margin-top: 2rem;'>
148
  <p>Created with ❀️ by AI. Powered by Streamlit and Hugging Face Transformers.</p>
149
  </div>
150
+ """, unsafe_allow_html=True)
151
+