from typing import List, Optional, Tuple, Union

import librosa
import numpy as np
import torch
import torchaudio
from transformers import pipeline


class MusicGenreClassifier:
    """Predict a music genre from an audio file or from lyrics text.

    Audio files are scored by a Hugging Face ``audio-classification``
    pipeline. Lyrics are scored by a ``zero-shot-classification`` pipeline
    against the labels in ``self.genres``.
    """

    # Rate (Hz) every audio file is resampled to before classification.
    # Assumes the audio model's feature extractor expects 16 kHz input --
    # TODO confirm against the checkpoint's preprocessor config.
    SAMPLE_RATE = 16000

    # File suffixes that make ``predict`` treat its input as an audio path.
    AUDIO_EXTENSIONS = ('.mp3', '.wav', '.ogg', '.flac')

    def __init__(self):
        """Build the text and audio pipelines and the candidate genre list."""
        self.text_classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
        )

        # NOTE(review): confirm this checkpoint name resolves on the Hub and
        # that its label set is music genres; it is kept exactly as written.
        self.audio_classifier = pipeline(
            "audio-classification",
            model="superb/wav2vec2-base-superb-gc",
        )

        # Candidate labels for zero-shot lyrics classification only; the
        # audio model reports whatever labels it was trained with.
        self.genres = [
            "rock", "pop", "hip hop", "country", "jazz",
            "classical", "electronic", "blues", "reggae", "metal",
        ]

    def process_audio(self, audio_path: str) -> torch.Tensor:
        """Load an audio file and return it as a float tensor.

        The file is decoded with ``librosa.load`` and resampled to
        ``SAMPLE_RATE``. A 1-D signal gets a leading dimension added, so the
        usual result has shape ``(1, n_samples)``.

        Args:
            audio_path: Path to the audio file.

        Returns:
            The waveform as a ``torch.float32`` tensor.

        Raises:
            ValueError: If the file cannot be loaded or converted.
        """
        try:
            samples, _ = librosa.load(audio_path, sr=self.SAMPLE_RATE)
            waveform = torch.from_numpy(samples).float()
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            return waveform
        except Exception as e:
            raise ValueError(f"Error processing audio file: {str(e)}") from e

    def classify_audio(self, audio_path: str) -> Tuple[str, float]:
        """Classify the genre of an audio file.

        Args:
            audio_path: Path to the audio file.

        Returns:
            ``(label, score)`` of the highest-scoring prediction.

        Raises:
            ValueError: If loading or classification fails, or the model
                returns no predictions.
        """
        try:
            waveform = self.process_audio(audio_path)

            # The pipeline takes a 1-D numpy signal, not a (1, n) tensor, so
            # drop the leading dimension that process_audio added.
            samples = waveform.squeeze(0).numpy()
            predictions = self.audio_classifier(samples, top_k=1)

            # One input normally yields a flat list of {"label", "score"}
            # dicts. Also tolerate a nested (batched) list or a bare dict.
            if (
                isinstance(predictions, list)
                and predictions
                and isinstance(predictions[0], list)
            ):
                predictions = predictions[0]
            if isinstance(predictions, dict):
                predictions = [predictions]
            if not predictions:
                raise ValueError("model returned no predictions")

            top_pred = max(predictions, key=lambda pred: pred['score'])
            return top_pred['label'], top_pred['score']
        except Exception as e:
            raise ValueError(f"Audio classification failed: {str(e)}") from e

    def classify_text(self, lyrics: str) -> Tuple[str, float]:
        """Classify the genre of a lyrics string via zero-shot classification.

        Args:
            lyrics: The lyrics text to classify.

        Returns:
            ``(label, score)`` for the first entry of the pipeline's
            ``labels`` / ``scores`` lists.

        Raises:
            ValueError: If the text pipeline fails.
        """
        try:
            # Each genre is substituted into "{}" to form the hypothesis.
            hypothesis_template = "This text contains {} music lyrics."

            result = self.text_classifier(
                lyrics,
                candidate_labels=self.genres,
                hypothesis_template=hypothesis_template,
            )
            return result['labels'][0], result['scores'][0]
        except Exception as e:
            raise ValueError(f"Text classification failed: {str(e)}") from e

    def predict(self, input_data: str, input_type: Optional[str] = None) -> dict:
        """
        Main prediction method that handles both audio and text inputs.

        Args:
            input_data: Path to an audio file, or lyrics text.
            input_type: Optional, 'audio' or 'text'. If None, the type is
                auto-detected: input ending in one of ``AUDIO_EXTENSIONS``
                (case-insensitive) is audio, anything else is text. Any
                explicit value other than 'audio' is handled as text.

        Returns:
            dict with keys 'genre', 'confidence' (float) and 'input_type'.

        Raises:
            ValueError: If the underlying classification fails.
        """
        if input_type is None:
            is_audio_path = input_data.lower().endswith(self.AUDIO_EXTENSIONS)
            input_type = 'audio' if is_audio_path else 'text'

        try:
            if input_type == 'audio':
                genre, confidence = self.classify_audio(input_data)
            else:
                genre, confidence = self.classify_text(input_data)

            return {
                'genre': genre,
                'confidence': float(confidence),
                'input_type': input_type,
            }
        except Exception as e:
            raise ValueError(f"Prediction failed: {str(e)}") from e