import librosa
import numpy as np
from transformers import pipeline
from typing import Optional, Tuple
class MusicGenreClassifier:
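    """
    Dual-input music genre classifier built on Hugging Face pipelines:
    zero-shot text classification (facebook/bart-large-mnli) for lyrics and
    audio classification (an AST model fine-tuned on AudioSet) for recordings,
    with model labels mapped onto a fixed list of standard genres.
    """
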
    def __init__(self):
        try:
            # Initialize both text and audio classification pipelines with auto device mapping
            self.text_classifier = pipeline(
                "zero-shot-classification",
                model="facebook/bart-large-mnli",
                device_map="auto"
            )
            # For audio, use MIT's Audio Spectrogram Transformer fine-tuned on AudioSet
            self.audio_classifier = pipeline(
                "audio-classification",
                model="mit/ast-finetuned-audioset-10-10-0.4593",
                device_map="auto"
            )
        except Exception as e:
            print(f"Warning: GPU initialization failed, falling back to CPU. Error: {str(e)}")
            # Fall back to CPU if GPU initialization fails
            self.text_classifier = pipeline(
                "zero-shot-classification",
                model="facebook/bart-large-mnli",
                device="cpu"
            )
            self.audio_classifier = pipeline(
                "audio-classification",
                model="mit/ast-finetuned-audioset-10-10-0.4593",
                device="cpu"
            )

        # Define standard genres for classification
        self.genres = [
            "rock", "pop", "hip hop", "country", "jazz",
            "classical", "electronic", "blues", "reggae", "metal"
        ]

        # Mapping from AudioSet-style model labels to our standard genres
        self.label_mapping = {
            "Music": "pop",  # Default mapping
            "Rock music": "rock",
            "Pop music": "pop",
            "Hip hop music": "hip hop",
            "Country": "country",
            "Jazz": "jazz",
            "Classical music": "classical",
            "Electronic music": "electronic",
            "Blues": "blues",
            "Reggae": "reggae",
            "Heavy metal": "metal"
        }
    def process_audio(self, audio_path: str) -> np.ndarray:
        """Load an audio file as a mono 16 kHz waveform, as the audio pipeline expects."""
        try:
            # librosa handles most common formats and resamples/downmixes on load
            waveform, _ = librosa.load(audio_path, sr=16000, mono=True)
            return waveform.astype(np.float32)
        except Exception as e:
            raise ValueError(f"Error processing audio file: {str(e)}")
    def map_label_to_genre(self, label: str) -> str:
        """Map a model output label to one of the standard genres."""
        return self.label_mapping.get(label, "pop")  # Default to pop if unknown
    def classify_audio(self, audio_path: str) -> Tuple[str, float]:
        """Classify genre from an audio file."""
        try:
            waveform = self.process_audio(audio_path)
            # The audio-classification pipeline accepts a raw waveform plus its sampling rate
            predictions = self.audio_classifier(
                {"raw": waveform, "sampling_rate": 16000}, top_k=3
            )
            # A single input yields a list of dicts; a batch yields a list of lists
            if predictions and isinstance(predictions[0], list):
                predictions = predictions[0]
            # Keep only predictions whose labels map to one of our standard genres
            music_preds = [
                (self.map_label_to_genre(p['label']), p['score'])
                for p in predictions
                if p['label'] in self.label_mapping
            ]
            if not music_preds:
                # If no music-related labels were found, return a default
                return "pop", 0.5
            # Return the highest scoring genre
            genre, score = max(music_preds, key=lambda x: x[1])
            return genre, score
        except Exception as e:
            raise ValueError(f"Audio classification failed: {str(e)}")
    def classify_text(self, lyrics: str) -> Tuple[str, float]:
        """Classify genre from lyrics text."""
        try:
            # Prepare the hypothesis template for zero-shot classification
            hypothesis_template = "This text contains {} music lyrics."
            result = self.text_classifier(
                lyrics,
                candidate_labels=self.genres,
                hypothesis_template=hypothesis_template
            )
            return result['labels'][0], result['scores'][0]
        except Exception as e:
            raise ValueError(f"Text classification failed: {str(e)}")
    def predict(self, input_data: str, input_type: Optional[str] = None) -> dict:
        """
        Main prediction method that handles both audio and text inputs.

        Args:
            input_data: Path to an audio file, or lyrics text
            input_type: Optional, 'audio' or 'text'. If None, the type is auto-detected

        Returns:
            dict containing the predicted genre, confidence score, and input type
        """
        # Auto-detect the input type from the file extension if not specified
        if input_type is None:
            input_type = 'audio' if input_data.lower().endswith(('.mp3', '.wav', '.ogg', '.flac')) else 'text'

        try:
            if input_type == 'audio':
                genre, confidence = self.classify_audio(input_data)
            else:
                genre, confidence = self.classify_text(input_data)

            return {
                'genre': genre,
                'confidence': float(confidence),
                'input_type': input_type
            }
        except Exception as e:
            raise ValueError(f"Prediction failed: {str(e)}")