import warnings
from typing import List, Optional, Tuple, Union

import librosa
import numpy as np
import torch
import torchaudio
from transformers import pipeline


class MusicGenreClassifier:
    """Predict a music genre from an audio file or from lyrics text.

    Audio is classified with an AudioSet-finetuned AST model whose output
    labels are mapped onto ``self.genres`` via ``self.label_mapping``; lyrics
    are classified with zero-shot NLI over the same ``self.genres`` list.
    """

    # Hugging Face model ids used by the two pipelines.
    TEXT_MODEL = "facebook/bart-large-mnli"
    AUDIO_MODEL = "mit/ast-finetuned-audioset-10-10-0.4593"
    # Rate (Hz) audio is resampled to on load; also passed to the audio
    # pipeline explicitly so it does not have to assume a rate.
    SAMPLE_RATE = 16000
    # Catch-all model label; only used when no genre-specific label matches.
    GENERIC_MUSIC_LABEL = "Music"
    # File suffixes that make predict() auto-detect its input as audio.
    AUDIO_EXTENSIONS = ('.mp3', '.wav', '.ogg', '.flac')

    def __init__(self):
        """Load both pipelines, retrying on CPU if automatic placement fails."""
        try:
            # device_map="auto" lets transformers choose device placement; any
            # failure here triggers the explicit CPU retry below.
            self._load_pipelines(device_map="auto")
        except Exception as e:
            warnings.warn(
                f"GPU initialization failed, falling back to CPU. Error: {e}",
                RuntimeWarning,
            )
            self._load_pipelines(device="cpu")

        # Standard genres used as zero-shot candidate labels for lyrics.
        self.genres = [
            "rock", "pop", "hip hop", "country", "jazz",
            "classical", "electronic", "blues", "reggae", "metal",
        ]

        # Mapping from audio-model output labels to our standard genres.
        self.label_mapping = {
            "Music": "pop",  # Generic fallback mapping
            "Rock music": "rock",
            "Pop music": "pop",
            "Hip hop music": "hip hop",
            "Country": "country",
            "Jazz": "jazz",
            "Classical music": "classical",
            "Electronic music": "electronic",
            "Blues": "blues",
            "Reggae": "reggae",
            "Heavy metal": "metal",
        }

    def _load_pipelines(self, **device_kwargs) -> None:
        """Build the text and audio pipelines with the given placement kwargs."""
        self.text_classifier = pipeline(
            "zero-shot-classification",
            model=self.TEXT_MODEL,
            **device_kwargs,
        )
        self.audio_classifier = pipeline(
            "audio-classification",
            model=self.AUDIO_MODEL,
            **device_kwargs,
        )

    def process_audio(self, audio_path: str) -> torch.Tensor:
        """Load an audio file as a float32 tensor resampled to ``SAMPLE_RATE``.

        Args:
            audio_path: Path to an audio file readable by librosa.

        Returns:
            Tensor of shape ``(1, num_samples)``.

        Raises:
            ValueError: if the file cannot be loaded.
        """
        try:
            # librosa handles more formats than torchaudio; sr forces resampling.
            waveform, _ = librosa.load(audio_path, sr=self.SAMPLE_RATE)
            tensor = torch.from_numpy(waveform).float()
            if tensor.dim() == 1:
                tensor = tensor.unsqueeze(0)
            return tensor
        except Exception as e:
            raise ValueError(f"Error processing audio file: {e}") from e

    def map_label_to_genre(self, label: str) -> str:
        """Map a model output label to a standard genre ("pop" if unknown)."""
        return self.label_mapping.get(label, "pop")

    def classify_audio(self, audio_path: str, top_k: int = 3) -> Tuple[str, float]:
        """Classify the genre of an audio file.

        Genre-specific labels take precedence over the catch-all "Music"
        label, so a generic match cannot mask an actual genre prediction.

        Args:
            audio_path: Path to an audio file readable by librosa.
            top_k: Number of top model labels to inspect for a genre match.

        Returns:
            ``(genre, score)``; ``("pop", 0.5)`` when none of the top ``top_k``
            labels appears in ``self.label_mapping``.

        Raises:
            ValueError: if loading or classification fails.
        """
        try:
            waveform = self.process_audio(audio_path)
            # The pipeline requires single-channel 1-D audio, while
            # process_audio returns shape (1, N): drop the leading axis.
            audio_input = {
                "raw": waveform.squeeze(0).cpu().numpy(),
                "sampling_rate": self.SAMPLE_RATE,
            }
            predictions = self.audio_classifier(audio_input, top_k=top_k)

            # A single input yields a flat list of {'label', 'score'} dicts;
            # only unwrap when the result is nested (batched output).
            if predictions and isinstance(predictions[0], list):
                predictions = predictions[0]

            mapped = [p for p in predictions if p['label'] in self.label_mapping]
            specific = [p for p in mapped if p['label'] != self.GENERIC_MUSIC_LABEL]
            best = max(specific or mapped, key=lambda p: p['score'], default=None)

            if best is None:
                # No music genre found among the top predictions.
                return "pop", 0.5
            return self.map_label_to_genre(best['label']), best['score']
        except Exception as e:
            raise ValueError(f"Audio classification failed: {e}") from e

    def classify_text(self, lyrics: str) -> Tuple[str, float]:
        """Classify the genre of lyrics via zero-shot classification.

        Returns:
            ``(genre, score)`` for the top-ranked candidate in ``self.genres``.

        Raises:
            ValueError: if ``lyrics`` is empty/blank or classification fails.
        """
        try:
            if not lyrics or not lyrics.strip():
                raise ValueError("lyrics must be a non-empty string")
            result = self.text_classifier(
                lyrics,
                candidate_labels=self.genres,
                hypothesis_template="This text contains {} music lyrics.",
            )
            # The zero-shot pipeline orders labels/scores best-first.
            return result['labels'][0], result['scores'][0]
        except Exception as e:
            raise ValueError(f"Text classification failed: {e}") from e

    def predict(self, input_data: str, input_type: Optional[str] = None) -> dict:
        """Predict a genre for an audio file path or a lyrics string.

        Args:
            input_data: Path to an audio file, or lyrics text.
            input_type: 'audio' or 'text' (case-insensitive). If None, input is
                treated as audio when its suffix is in ``AUDIO_EXTENSIONS``,
                otherwise as text.

        Returns:
            dict with keys 'genre', 'confidence' (float) and 'input_type'.

        Raises:
            ValueError: on empty input, unknown ``input_type``, or when
                classification fails.
        """
        if not isinstance(input_data, str) or not input_data:
            raise ValueError("Prediction failed: input_data must be a non-empty string")

        if input_type is None:
            input_type = 'audio' if input_data.lower().endswith(self.AUDIO_EXTENSIONS) else 'text'
        else:
            input_type = input_type.lower()
            if input_type not in ('audio', 'text'):
                raise ValueError(
                    f"Prediction failed: unsupported input_type {input_type!r}; "
                    "expected 'audio' or 'text'"
                )

        try:
            if input_type == 'audio':
                genre, confidence = self.classify_audio(input_data)
            else:
                genre, confidence = self.classify_text(input_data)
            return {
                'genre': genre,
                'confidence': float(confidence),
                'input_type': input_type,
            }
        except Exception as e:
            raise ValueError(f"Prediction failed: {e}") from e