jacob-c committed on
Commit
3dc83fd
·
1 Parent(s): 46e5e67
Files changed (2) hide show
  1. README.md +16 -1
  2. src/classifier.py +41 -6
README.md CHANGED
@@ -75,10 +75,24 @@ python app.py
75
  ## Models Used
76
 
77
  - **Genre Classification**:
78
- - Audio: `superb/wav2vec2-base-superb-gc` (Pre-trained on music genre classification)
79
  - Text: `facebook/bart-large-mnli` (Zero-shot classification)
80
  - **Lyric Generation**: `gpt2-medium`
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  ## Contributing
83
 
84
  Contributions are welcome! Please feel free to submit a Pull Request.
@@ -89,6 +103,7 @@ This project is licensed under the MIT License - see the LICENSE file for detail
89
 
90
  ## Acknowledgments
91
 
 
92
  - Hugging Face for providing the pre-trained models
93
  - Gradio for the web interface framework
94
  - The open-source community for various audio processing libraries
 
75
  ## Models Used
76
 
77
  - **Genre Classification**:
78
+ - Audio: `mit/ast-finetuned-audioset-10-10-0.4593` (MIT's Audio Spectrogram Transformer)
79
  - Text: `facebook/bart-large-mnli` (Zero-shot classification)
80
  - **Lyric Generation**: `gpt2-medium`
81
 
82
+ ## Supported Genres
83
+
84
+ The system supports classification and generation for the following genres:
85
+ - Rock
86
+ - Pop
87
+ - Hip Hop
88
+ - Country
89
+ - Jazz
90
+ - Classical
91
+ - Electronic
92
+ - Blues
93
+ - Reggae
94
+ - Metal
95
+
96
  ## Contributing
97
 
98
  Contributions are welcome! Please feel free to submit a Pull Request.
 
103
 
104
  ## Acknowledgments
105
 
106
+ - MIT for the Audio Spectrogram Transformer model
107
  - Hugging Face for providing the pre-trained models
108
  - Gradio for the web interface framework
109
  - The open-source community for various audio processing libraries
src/classifier.py CHANGED
@@ -13,16 +13,32 @@ class MusicGenreClassifier:
13
  model="facebook/bart-large-mnli"
14
  )
15
 
16
- # For audio classification, we'll use a different pre-trained model
17
  self.audio_classifier = pipeline(
18
  "audio-classification",
19
- model="superb/wav2vec2-base-superb-gc"
20
  )
21
 
 
22
  self.genres = [
23
  "rock", "pop", "hip hop", "country", "jazz",
24
  "classical", "electronic", "blues", "reggae", "metal"
25
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def process_audio(self, audio_path: str) -> torch.Tensor:
28
  """Process audio file to match model requirements."""
@@ -37,16 +53,35 @@ class MusicGenreClassifier:
37
  except Exception as e:
38
  raise ValueError(f"Error processing audio file: {str(e)}")
39
 
 
 
 
 
40
  def classify_audio(self, audio_path: str) -> Tuple[str, float]:
41
  """Classify genre from audio file."""
42
  try:
43
  waveform = self.process_audio(audio_path)
44
- predictions = self.audio_classifier(waveform, top_k=1)
45
- # Get the top prediction
 
46
  if isinstance(predictions, list):
47
  predictions = predictions[0]
48
- top_pred = max(predictions, key=lambda x: x['score'])
49
- return top_pred['label'], top_pred['score']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  except Exception as e:
51
  raise ValueError(f"Audio classification failed: {str(e)}")
52
 
 
13
  model="facebook/bart-large-mnli"
14
  )
15
 
16
+ # For audio classification, we'll use MIT's Audio Spectrogram Transformer (fine-tuned on AudioSet)
17
  self.audio_classifier = pipeline(
18
  "audio-classification",
19
+ model="mit/ast-finetuned-audioset-10-10-0.4593"
20
  )
21
 
22
+ # Define standard genres for classification
23
  self.genres = [
24
  "rock", "pop", "hip hop", "country", "jazz",
25
  "classical", "electronic", "blues", "reggae", "metal"
26
  ]
27
+
28
+ # Mapping from model output labels to our standard genres
29
+ self.label_mapping = {
30
+ "Music": "pop", # Default mapping
31
+ "Rock music": "rock",
32
+ "Pop music": "pop",
33
+ "Hip hop music": "hip hop",
34
+ "Country": "country",
35
+ "Jazz": "jazz",
36
+ "Classical music": "classical",
37
+ "Electronic music": "electronic",
38
+ "Blues": "blues",
39
+ "Reggae": "reggae",
40
+ "Heavy metal": "metal"
41
+ }
42
 
43
  def process_audio(self, audio_path: str) -> torch.Tensor:
44
  """Process audio file to match model requirements."""
 
53
  except Exception as e:
54
  raise ValueError(f"Error processing audio file: {str(e)}")
55
 
56
+ def map_label_to_genre(self, label: str) -> str:
57
+ """Map model output label to standard genre."""
58
+ return self.label_mapping.get(label, "pop") # Default to pop if unknown
59
+
60
  def classify_audio(self, audio_path: str) -> Tuple[str, float]:
61
  """Classify genre from audio file."""
62
  try:
63
  waveform = self.process_audio(audio_path)
64
+ predictions = self.audio_classifier(waveform, top_k=3)
65
+
66
+ # Process predictions
67
  if isinstance(predictions, list):
68
  predictions = predictions[0]
69
+
70
+ # Find the highest scoring music-related prediction
71
+ music_preds = [
72
+ (self.map_label_to_genre(p['label']), p['score'])
73
+ for p in predictions
74
+ if p['label'] in self.label_mapping
75
+ ]
76
+
77
+ if not music_preds:
78
+ # If no music genres found, return default
79
+ return "pop", 0.5
80
+
81
+ # Get the highest scoring genre
82
+ genre, score = max(music_preds, key=lambda x: x[1])
83
+ return genre, score
84
+
85
  except Exception as e:
86
  raise ValueError(f"Audio classification failed: {str(e)}")
87