Spaces:
Running
Running
- README.md +16 -1
- src/classifier.py +41 -6
README.md
CHANGED
@@ -75,10 +75,24 @@ python app.py
|
|
75 |
## Models Used
|
76 |
|
77 |
- **Genre Classification**:
|
78 |
-
- Audio: `
|
79 |
- Text: `facebook/bart-large-mnli` (Zero-shot classification)
|
80 |
- **Lyric Generation**: `gpt2-medium`
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
## Contributing
|
83 |
|
84 |
Contributions are welcome! Please feel free to submit a Pull Request.
|
@@ -89,6 +103,7 @@ This project is licensed under the MIT License - see the LICENSE file for detail
|
|
89 |
|
90 |
## Acknowledgments
|
91 |
|
|
|
92 |
- Hugging Face for providing the pre-trained models
|
93 |
- Gradio for the web interface framework
|
94 |
- The open-source community for various audio processing libraries
|
|
|
75 |
## Models Used
|
76 |
|
77 |
- **Genre Classification**:
|
78 |
+
- Audio: `MIT/ast-finetuned-audioset-10-10-0.4593` (MIT's Audio Spectrogram Transformer, fine-tuned on AudioSet)
|
79 |
- Text: `facebook/bart-large-mnli` (Zero-shot classification)
|
80 |
- **Lyric Generation**: `gpt2-medium`
|
81 |
|
82 |
+
## Supported Genres
|
83 |
+
|
84 |
+
The system supports classification and generation for the following genres:
|
85 |
+
- Rock
|
86 |
+
- Pop
|
87 |
+
- Hip Hop
|
88 |
+
- Country
|
89 |
+
- Jazz
|
90 |
+
- Classical
|
91 |
+
- Electronic
|
92 |
+
- Blues
|
93 |
+
- Reggae
|
94 |
+
- Metal
|
95 |
+
|
96 |
## Contributing
|
97 |
|
98 |
Contributions are welcome! Please feel free to submit a Pull Request.
|
|
|
103 |
|
104 |
## Acknowledgments
|
105 |
|
106 |
+
- MIT for the Audio Spectrogram Transformer model
|
107 |
- Hugging Face for providing the pre-trained models
|
108 |
- Gradio for the web interface framework
|
109 |
- The open-source community for various audio processing libraries
|
src/classifier.py
CHANGED
@@ -13,16 +13,32 @@ class MusicGenreClassifier:
|
|
13 |
model="facebook/bart-large-mnli"
|
14 |
)
|
15 |
|
16 |
-
# For audio classification, we'll use
|
17 |
self.audio_classifier = pipeline(
|
18 |
"audio-classification",
|
19 |
-
model="
|
20 |
)
|
21 |
|
|
|
22 |
self.genres = [
|
23 |
"rock", "pop", "hip hop", "country", "jazz",
|
24 |
"classical", "electronic", "blues", "reggae", "metal"
|
25 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
def process_audio(self, audio_path: str) -> torch.Tensor:
|
28 |
"""Process audio file to match model requirements."""
|
@@ -37,16 +53,35 @@ class MusicGenreClassifier:
|
|
37 |
except Exception as e:
|
38 |
raise ValueError(f"Error processing audio file: {str(e)}")
|
39 |
|
|
|
|
|
|
|
|
|
40 |
def classify_audio(self, audio_path: str) -> Tuple[str, float]:
|
41 |
"""Classify genre from audio file."""
|
42 |
try:
|
43 |
waveform = self.process_audio(audio_path)
|
44 |
-
predictions = self.audio_classifier(waveform, top_k=
|
45 |
-
|
|
|
46 |
if isinstance(predictions, list):
|
47 |
predictions = predictions[0]
|
48 |
-
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
except Exception as e:
|
51 |
raise ValueError(f"Audio classification failed: {str(e)}")
|
52 |
|
|
|
13 |
model="facebook/bart-large-mnli"
|
14 |
)
|
15 |
|
16 |
+
# For audio classification, we'll use MIT's Audio Spectrogram Transformer (fine-tuned on AudioSet)
|
17 |
self.audio_classifier = pipeline(
|
18 |
"audio-classification",
|
19 |
+
model="mit/ast-finetuned-audioset-10-10-0.4593"
|
20 |
)
|
21 |
|
22 |
+
# Define standard genres for classification
|
23 |
self.genres = [
|
24 |
"rock", "pop", "hip hop", "country", "jazz",
|
25 |
"classical", "electronic", "blues", "reggae", "metal"
|
26 |
]
|
27 |
+
|
28 |
+
# Mapping from model output labels to our standard genres
|
29 |
+
self.label_mapping = {
|
30 |
+
"Music": "pop", # Default mapping
|
31 |
+
"Rock music": "rock",
|
32 |
+
"Pop music": "pop",
|
33 |
+
"Hip hop music": "hip hop",
|
34 |
+
"Country": "country",
|
35 |
+
"Jazz": "jazz",
|
36 |
+
"Classical music": "classical",
|
37 |
+
"Electronic music": "electronic",
|
38 |
+
"Blues": "blues",
|
39 |
+
"Reggae": "reggae",
|
40 |
+
"Heavy metal": "metal"
|
41 |
+
}
|
42 |
|
43 |
def process_audio(self, audio_path: str) -> torch.Tensor:
|
44 |
"""Process audio file to match model requirements."""
|
|
|
53 |
except Exception as e:
|
54 |
raise ValueError(f"Error processing audio file: {str(e)}")
|
55 |
|
56 |
+
def map_label_to_genre(self, label: str) -> str:
    """Map a model output label to one of the standard genres.

    Args:
        label: Label string as emitted by the audio classifier.

    Returns:
        The genre stored for ``label`` in ``self.label_mapping``, or
        ``"pop"`` when the label has no entry there.
    """
    # Unknown labels fall back to "pop", the same default the mapping
    # table uses for the generic "Music" label.
    return self.label_mapping.get(label, "pop")
|
59 |
+
|
60 |
def classify_audio(self, audio_path: str) -> Tuple[str, float]:
    """Classify genre from audio file.

    Args:
        audio_path: Path of the audio file to classify.

    Returns:
        A ``(genre, score)`` tuple for the highest-scoring prediction whose
        label is a key of ``self.label_mapping``. When none of the top
        predictions is a mapped label, ``("pop", 0.5)`` is returned.

    Raises:
        ValueError: If audio processing or classification fails; the
            original exception is attached as ``__cause__``.
    """
    try:
        waveform = self.process_audio(audio_path)
        predictions = self.audio_classifier(waveform, top_k=3)

        # Normalise the classifier output to a flat list of
        # {"label": ..., "score": ...} dicts. A single input yields a flat
        # list already; only a nested (batched) result needs unwrapping.
        # Unconditionally taking element 0 would turn the flat list into a
        # single dict and break the iteration below.
        if isinstance(predictions, dict):
            predictions = [predictions]
        elif (
            isinstance(predictions, list)
            and predictions
            and isinstance(predictions[0], (list, tuple))
        ):
            predictions = predictions[0]

        # Keep only predictions whose label maps to a known genre.
        music_preds = [
            (self.map_label_to_genre(p['label']), p['score'])
            for p in predictions
            if p['label'] in self.label_mapping
        ]

        if not music_preds:
            # If no music genres found, return default
            return "pop", 0.5

        # Get the highest scoring genre
        genre, score = max(music_preds, key=lambda x: x[1])
        return genre, score

    except Exception as e:
        raise ValueError(f"Audio classification failed: {str(e)}") from e
|
87 |
|