import torch
import torchaudio
import librosa
import numpy as np
from transformers import pipeline
from typing import Union, Tuple, List

class MusicGenreClassifier:
    def __init__(self):
        try:
            # Initialize both audio and text classification pipelines with auto device mapping
            self.text_classifier = pipeline(
                "zero-shot-classification",
                model="facebook/bart-large-mnli",
                device_map="auto"
            )
            # For audio classification, we'll use MIT's music classification model
            self.audio_classifier = pipeline(
                "audio-classification",
                model="mit/ast-finetuned-audioset-10-10-0.4593",
                device_map="auto"
            )
        except Exception as e:
            print(f"Warning: GPU initialization failed, falling back to CPU. Error: {str(e)}")
            # Fall back to CPU if GPU initialization fails
            self.text_classifier = pipeline(
                "zero-shot-classification",
                model="facebook/bart-large-mnli",
                device="cpu"
            )
            self.audio_classifier = pipeline(
                "audio-classification",
                model="mit/ast-finetuned-audioset-10-10-0.4593",
                device="cpu"
            )

        # Define standard genres for classification
        self.genres = [
            "rock", "pop", "hip hop", "country", "jazz",
            "classical", "electronic", "blues", "reggae", "metal"
        ]

        # Mapping from model output labels to our standard genres
        self.label_mapping = {
            "Music": "pop",  # Default mapping
            "Rock music": "rock",
            "Pop music": "pop",
            "Hip hop music": "hip hop",
            "Country": "country",
            "Jazz": "jazz",
            "Classical music": "classical",
            "Electronic music": "electronic",
            "Blues": "blues",
            "Reggae": "reggae",
            "Heavy metal": "metal"
        }

    def process_audio(self, audio_path: str) -> torch.Tensor:
        """Process audio file to match model requirements."""
        try:
            # Load audio using librosa (handles more formats)
            waveform, sample_rate = librosa.load(audio_path, sr=16000)
            # Convert to torch tensor and ensure proper shape
            waveform = torch.from_numpy(waveform).float()
            if len(waveform.shape) == 1:
                waveform = waveform.unsqueeze(0)
            return waveform
        except Exception as e:
            raise ValueError(f"Error processing audio file: {str(e)}")

    def map_label_to_genre(self, label: str) -> str:
        """Map model output label to standard genre."""
        return self.label_mapping.get(label, "pop")  # Default to pop if unknown

    def classify_audio(self, audio_path: str) -> Tuple[str, float]:
        """Classify genre from audio file."""
        try:
            waveform = self.process_audio(audio_path)
            # The audio-classification pipeline expects a 1-D numpy array
            # (already resampled to 16 kHz in process_audio), not a torch tensor
            predictions = self.audio_classifier(
                waveform.squeeze(0).numpy(), top_k=3
            )
            # For a single input the pipeline returns a list of {'label', 'score'}
            # dicts; only unwrap one level if a batched list-of-lists comes back
            if predictions and isinstance(predictions[0], list):
                predictions = predictions[0]
            # Find the highest scoring music-related prediction
            music_preds = [
                (self.map_label_to_genre(p['label']), p['score'])
                for p in predictions
                if p['label'] in self.label_mapping
            ]
            if not music_preds:
                # If no music genres found, return default
                return "pop", 0.5
            # Get the highest scoring genre
            genre, score = max(music_preds, key=lambda x: x[1])
            return genre, score
        except Exception as e:
            raise ValueError(f"Audio classification failed: {str(e)}")

    def classify_text(self, lyrics: str) -> Tuple[str, float]:
        """Classify genre from lyrics text."""
        try:
            # Prepare the hypothesis template for zero-shot classification
            hypothesis_template = "This text contains {} music lyrics."
            result = self.text_classifier(
                lyrics,
                candidate_labels=self.genres,
                hypothesis_template=hypothesis_template
            )
            return result['labels'][0], result['scores'][0]
        except Exception as e:
            raise ValueError(f"Text classification failed: {str(e)}")

    def predict(self, input_data: str, input_type: str = None) -> dict:
        """
        Main prediction method that handles both audio and text inputs.

        Args:
            input_data: Path to audio file or lyrics text
            input_type: Optional, 'audio' or 'text'. If None, will try to auto-detect

        Returns:
            dict containing predicted genre and confidence score
        """
        # Try to auto-detect input type if not specified
        if input_type is None:
            input_type = 'audio' if input_data.lower().endswith(('.mp3', '.wav', '.ogg', '.flac')) else 'text'
        try:
            if input_type == 'audio':
                genre, confidence = self.classify_audio(input_data)
            else:
                genre, confidence = self.classify_text(input_data)
            return {
                'genre': genre,
                'confidence': float(confidence),
                'input_type': input_type
            }
        except Exception as e:
            raise ValueError(f"Prediction failed: {str(e)}")