# Spaces: Running
from typing import List, Optional, Tuple, Union

import librosa
import numpy as np
import torch
import torchaudio
from transformers import pipeline
class MusicGenreClassifier:
    """Predict a music genre from either an audio file or lyrics text.

    Two Hugging Face pipelines are built at construction time: a zero-shot
    text classifier that scores lyrics against ``self.genres``, and an
    audio-classification pipeline that is run on the decoded waveform.
    """

    # Sampling rate (Hz) that audio is resampled to when it is loaded.
    _SAMPLE_RATE = 16000
    # File suffixes that make ``predict`` treat its input as an audio path.
    _AUDIO_EXTENSIONS = ('.mp3', '.wav', '.ogg', '.flac')

    def __init__(self):
        """Load both pipelines and define the candidate genre labels."""
        # Zero-shot classifier used for lyrics.
        self.text_classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli"
        )
        # Separate pre-trained model for audio input.
        # NOTE(review): confirm this checkpoint exists on the Hub and that its
        # labels are music genres; the labels it emits are returned as-is and
        # are not mapped onto ``self.genres``.
        self.audio_classifier = pipeline(
            "audio-classification",
            model="superb/wav2vec2-base-superb-gc"
        )
        # Candidate labels handed to the zero-shot text classifier.
        self.genres = [
            "rock", "pop", "hip hop", "country", "jazz",
            "classical", "electronic", "blues", "reggae", "metal"
        ]

    def process_audio(self, audio_path: str) -> torch.Tensor:
        """Load an audio file and return it as a float tensor.

        Args:
            audio_path: Path of the audio file to decode.

        Returns:
            The waveform resampled to ``_SAMPLE_RATE``. A 1-D signal gets a
            leading axis added, so the result has at least two dimensions.

        Raises:
            ValueError: If the file cannot be loaded or converted.
        """
        try:
            # librosa is used for loading because it handles more formats.
            waveform, _ = librosa.load(audio_path, sr=self._SAMPLE_RATE)
            waveform = torch.from_numpy(waveform).float()
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            return waveform
        except Exception as e:
            raise ValueError(f"Error processing audio file: {str(e)}") from e

    def classify_audio(self, audio_path: str) -> Tuple[str, float]:
        """Classify the genre of an audio file.

        Args:
            audio_path: Path of the audio file to classify.

        Returns:
            ``(label, score)`` of the highest-scoring prediction.

        Raises:
            ValueError: If loading, inference, or result parsing fails.
        """
        try:
            waveform = self.process_audio(audio_path)
            # The pipeline takes raw audio as a 1-D numpy array, so collapse
            # the leading channel axis added by ``process_audio``.
            if waveform.dim() > 1:
                waveform = waveform.mean(dim=0)
            samples = waveform.detach().cpu().numpy()

            predictions = self.audio_classifier(samples, top_k=1)
            # Normalise the result to a flat list of {'label', 'score'} dicts:
            # a single input yields a list of dicts, a batch a list of lists.
            if isinstance(predictions, dict):
                predictions = [predictions]
            elif predictions and isinstance(predictions[0], list):
                predictions = predictions[0]
            if not predictions:
                raise ValueError("model returned no predictions")

            top_pred = max(predictions, key=lambda pred: pred['score'])
            return top_pred['label'], float(top_pred['score'])
        except Exception as e:
            raise ValueError(f"Audio classification failed: {str(e)}") from e

    def classify_text(self, lyrics: str) -> Tuple[str, float]:
        """Classify the genre of lyrics text with the zero-shot pipeline.

        Args:
            lyrics: The lyrics to classify.

        Returns:
            ``(label, score)`` taken from the first entry of the pipeline's
            ``labels`` and ``scores`` lists.

        Raises:
            ValueError: If the pipeline call or result parsing fails.
        """
        try:
            # Each candidate genre is substituted into this hypothesis.
            hypothesis_template = "This text contains {} music lyrics."
            result = self.text_classifier(
                lyrics,
                candidate_labels=self.genres,
                hypothesis_template=hypothesis_template
            )
            return result['labels'][0], result['scores'][0]
        except Exception as e:
            raise ValueError(f"Text classification failed: {str(e)}") from e

    def predict(self, input_data: str, input_type: Optional[str] = None) -> dict:
        """Predict a genre from an audio path or from lyrics text.

        Args:
            input_data: Path to an audio file, or the lyrics text itself.
            input_type: ``'audio'`` or ``'text'``. When ``None`` the type is
                inferred: input ending in one of ``_AUDIO_EXTENSIONS``
                (case-insensitive) is audio, anything else is text. Any
                explicit value other than ``'audio'`` is handled as text.

        Returns:
            A dict with keys ``'genre'``, ``'confidence'`` (float) and
            ``'input_type'``.

        Raises:
            ValueError: If the underlying classification fails.
        """
        if input_type is None:
            looks_like_audio = input_data.lower().endswith(self._AUDIO_EXTENSIONS)
            input_type = 'audio' if looks_like_audio else 'text'

        try:
            if input_type == 'audio':
                genre, confidence = self.classify_audio(input_data)
            else:
                genre, confidence = self.classify_text(input_data)
            return {
                'genre': genre,
                'confidence': float(confidence),
                'input_type': input_type
            }
        except Exception as e:
            raise ValueError(f"Prediction failed: {str(e)}") from e