ibrahim313 commited on
Commit
9584b8e
·
verified ·
1 Parent(s): df4f947

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import tempfile
4
+ from transformers import pipeline
5
+ import librosa
6
+ import soundfile as sf
7
+ from streamlit.components.v1 import html
8
+ import base64
9
+
10
+ # Set page configuration
11
+ st.set_page_config(page_title="Music Genre Classifier", layout="wide")
12
+
13
+ # Custom CSS for styling
14
+ custom_css = """
15
+ <style>
16
+ .stApp {
17
+ background-color: #f0f0f5;
18
+ }
19
+ .main-title {
20
+ color: #1e1e1e;
21
+ font-size: 3em;
22
+ font-weight: bold;
23
+ text-align: center;
24
+ margin-bottom: 30px;
25
+ }
26
+ .sub-title {
27
+ color: #4a4a4a;
28
+ font-size: 1.5em;
29
+ text-align: center;
30
+ margin-bottom: 20px;
31
+ }
32
+ .result-container {
33
+ background-color: #ffffff;
34
+ border-radius: 10px;
35
+ padding: 20px;
36
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
37
+ }
38
+ .genre-result {
39
+ font-size: 2em;
40
+ font-weight: bold;
41
+ text-align: center;
42
+ color: #2c3e50;
43
+ }
44
+ .confidence-bar {
45
+ height: 30px;
46
+ background-color: #3498db;
47
+ border-radius: 15px;
48
+ }
49
+ </style>
50
+ """
51
+
52
+ # Render custom CSS
53
+ st.markdown(custom_css, unsafe_allow_html=True)
54
+
55
+ # Load the audio classification model
56
+ @st.cache_resource
57
+ def load_model():
58
+ try:
59
+ return pipeline("audio-classification", model="sandychoii/distilhubert-finetuned-gtzan-audio-classification")
60
+ except Exception as e:
61
+ st.error(f"Error loading the model: {str(e)}")
62
+ return None
63
+
64
+ pipe = load_model()
65
+
66
+ # Function to classify audio
67
+ def classify_audio(audio_file):
68
+ try:
69
+ # Load audio file
70
+ y, sr = librosa.load(audio_file, sr=None)
71
+
72
+ # Ensure the audio is at least 3 seconds long (model requirement)
73
+ if len(y) < 3 * sr:
74
+ y = librosa.util.fix_length(y, size=3 * sr)
75
+
76
+ # Classification
77
+ result = pipe(y, sampling_rate=sr)
78
+ return result
79
+ except Exception as e:
80
+ st.error(f"Error during classification: {str(e)}")
81
+ return None
82
+
83
+ # Main app
84
+ def main():
85
+ st.markdown("<h1 class='main-title'>🎵 Music Genre Classifier 🎸</h1>", unsafe_allow_html=True)
86
+ st.markdown("<p class='sub-title'>Upload a music file and let AI detect its genre!</p>", unsafe_allow_html=True)
87
+
88
+ # File uploader
89
+ uploaded_file = st.file_uploader("Choose an audio file", type=["wav", "mp3", "ogg"])
90
+
91
+ if uploaded_file is not None:
92
+ # Display audio player
93
+ st.audio(uploaded_file)
94
+
95
+ # Classify button
96
+ if st.button("Classify Genre"):
97
+ with st.spinner("Analyzing the music... 🎧"):
98
+ # Save uploaded file temporarily
99
+ with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
100
+ tmp_file.write(uploaded_file.getvalue())
101
+ tmp_file_path = tmp_file.name
102
+
103
+ # Perform classification
104
+ result = classify_audio(tmp_file_path)
105
+
106
+ # Remove temporary file
107
+ os.unlink(tmp_file_path)
108
+
109
+ if result:
110
+ # Display results
111
+ st.markdown("<div class='result-container'>", unsafe_allow_html=True)
112
+ st.markdown(f"<h2 class='genre-result'>Detected Genre: {result[0]['label'].capitalize()}</h2>", unsafe_allow_html=True)
113
+
114
+ # Display confidence bar
115
+ confidence = result[0]['score']
116
+ st.markdown(f"<div class='confidence-bar' style='width: {confidence*100}%;'></div>", unsafe_allow_html=True)
117
+ st.write(f"Confidence: {confidence:.2%}")
118
+
119
+ # Display top 3 predictions
120
+ st.write("Top 3 Predictions:")
121
+ for r in result[:3]:
122
+ st.write(f"- {r['label'].capitalize()}: {r['score']:.2%}")
123
+ st.markdown("</div>", unsafe_allow_html=True)
124
+
125
+ # Add information about the model
126
+ st.sidebar.title("About")
127
+ st.sidebar.info("This app uses a fine-tuned DistilHuBERT model to classify music genres. It can identify genres like rock, pop, hip-hop, classical, and more!")
128
+
129
+ # Add a footer
130
+ footer_html = """
131
+ <div style="position: fixed; bottom: 0; width: 100%; text-align: center; padding: 10px; background-color: #f0f0f5;">
132
+ <p>Created with ❤️ by AI. Powered by Streamlit and Hugging Face Transformers.</p>
133
+ </div>
134
+ """
135
+ st.markdown(footer_html, unsafe_allow_html=True)
136
+
137
+ if __name__ == "__main__":
138
+ main()