# Repository file view: fyp_start_space / src / classifier.py
# (jacob-c · commit 8599ceb · raw · history · blame · 5.71 kB)
import torch
import torchaudio
import librosa
import numpy as np
from transformers import pipeline
from typing import Union, Tuple, List
class MusicGenreClassifier:
    """Predict a music genre from either an audio file or lyrics text.

    Two Hugging Face pipelines are used: a zero-shot text classifier for
    lyrics and an audio classifier whose output labels are translated to
    the genres in ``self.genres`` through ``self.label_mapping``.
    """

    # Model identifiers passed to ``transformers.pipeline``.
    TEXT_MODEL = "facebook/bart-large-mnli"
    AUDIO_MODEL = "mit/ast-finetuned-audioset-10-10-0.4593"
    # Rate (Hz) that audio is resampled to on load.
    # NOTE(review): presumably the rate the audio model was trained on - confirm.
    SAMPLE_RATE = 16000

    def __init__(self):
        """Load both pipelines, retrying on CPU if automatic placement fails."""
        try:
            # Let the library choose device placement first.
            self._load_pipelines(device_map="auto")
        except Exception as e:
            print(f"Warning: GPU initialization failed, falling back to CPU. Error: {str(e)}")
            # Any failure above triggers a full reload of both pipelines on CPU.
            self._load_pipelines(device="cpu")

        # Standard genres; also the candidate labels for lyrics classification.
        self.genres = [
            "rock", "pop", "hip hop", "country", "jazz",
            "classical", "electronic", "blues", "reggae", "metal"
        ]

        # Mapping from audio-model output labels to our standard genres.
        self.label_mapping = {
            "Music": "pop",  # Default mapping
            "Rock music": "rock",
            "Pop music": "pop",
            "Hip hop music": "hip hop",
            "Country": "country",
            "Jazz": "jazz",
            "Classical music": "classical",
            "Electronic music": "electronic",
            "Blues": "blues",
            "Reggae": "reggae",
            "Heavy metal": "metal"
        }

    def _load_pipelines(self, **device_kwargs) -> None:
        """Create the text and audio pipelines with the given device options."""
        self.text_classifier = pipeline(
            "zero-shot-classification",
            model=self.TEXT_MODEL,
            **device_kwargs
        )
        self.audio_classifier = pipeline(
            "audio-classification",
            model=self.AUDIO_MODEL,
            **device_kwargs
        )

    def process_audio(self, audio_path: str) -> torch.Tensor:
        """Load an audio file and return it as a float tensor.

        Args:
            audio_path: Path to the audio file.

        Returns:
            The waveform resampled to ``SAMPLE_RATE``; a 1-D signal is given a
            leading dimension so the result has shape ``(1, num_samples)``.

        Raises:
            ValueError: If the file cannot be loaded or converted.
        """
        try:
            # Load audio using librosa (handles more formats)
            samples, _ = librosa.load(audio_path, sr=self.SAMPLE_RATE)
            waveform = torch.from_numpy(samples).float()
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            return waveform
        except Exception as e:
            raise ValueError(f"Error processing audio file: {str(e)}") from e

    def map_label_to_genre(self, label: str) -> str:
        """Map a model output label to a standard genre ("pop" if unknown)."""
        return self.label_mapping.get(label, "pop")

    def _select_genre(self, predictions) -> Tuple[str, float]:
        """Pick the best-scoring mapped genre from pipeline predictions.

        Args:
            predictions: A list of ``{"label", "score"}`` dicts, or a list
                containing one such list (batched pipeline output).

        Returns:
            ``(genre, score)`` for the highest-scoring prediction whose label
            is in ``self.label_mapping``, or ``("pop", 0.5)`` when none is.
        """
        # Only unwrap genuinely nested output; a flat list of dicts is already
        # the per-clip result and must not be reduced to its first element.
        if predictions and isinstance(predictions[0], list):
            predictions = predictions[0]

        # NOTE: the generic "Music" label competes with the specific genre
        # labels here, so a high "Music" score yields "pop".
        candidates = [
            (self.map_label_to_genre(p['label']), p['score'])
            for p in predictions
            if p['label'] in self.label_mapping
        ]
        if not candidates:
            # No recognised music label: return the default genre/confidence.
            return "pop", 0.5
        return max(candidates, key=lambda pair: pair[1])

    def classify_audio(self, audio_path: str) -> Tuple[str, float]:
        """Classify genre from an audio file.

        Args:
            audio_path: Path to the audio file.

        Returns:
            ``(genre, score)`` as chosen by ``_select_genre`` from the top 3
            predictions of the audio pipeline.

        Raises:
            ValueError: If loading or classification fails.
        """
        try:
            waveform = self.process_audio(audio_path)
            # The audio-classification pipeline expects a single-channel 1-D
            # numpy array, so flatten the (1, num_samples) tensor first.
            samples = waveform.reshape(-1).cpu().numpy()
            predictions = self.audio_classifier(samples, top_k=3)
            return self._select_genre(predictions)
        except Exception as e:
            raise ValueError(f"Audio classification failed: {str(e)}") from e

    def classify_text(self, lyrics: str) -> Tuple[str, float]:
        """Classify genre from lyrics text using zero-shot classification.

        Args:
            lyrics: The lyrics to classify.

        Returns:
            The first label and first score of the pipeline result, with
            ``self.genres`` as the candidate labels.

        Raises:
            ValueError: If the pipeline call fails.
        """
        try:
            # Prepare the hypothesis template for zero-shot classification
            hypothesis_template = "This text contains {} music lyrics."
            result = self.text_classifier(
                lyrics,
                candidate_labels=self.genres,
                hypothesis_template=hypothesis_template
            )
            return result['labels'][0], result['scores'][0]
        except Exception as e:
            raise ValueError(f"Text classification failed: {str(e)}") from e

    def predict(self, input_data: str, input_type: Union[str, None] = None) -> dict:
        """
        Main prediction method that handles both audio and text inputs.

        Args:
            input_data: Path to audio file or lyrics text
            input_type: Optional, 'audio' or 'text'. If None, the input is
                treated as audio when it ends with .mp3, .wav, .ogg or .flac
                (case-insensitive) and as text otherwise. Any value other
                than 'audio' is handled as text.

        Returns:
            dict with keys 'genre', 'confidence' (float) and 'input_type'

        Raises:
            ValueError: If the underlying classification fails.
        """
        # Try to auto-detect input type if not specified
        if input_type is None:
            audio_suffixes = ('.mp3', '.wav', '.ogg', '.flac')
            input_type = 'audio' if input_data.lower().endswith(audio_suffixes) else 'text'

        try:
            if input_type == 'audio':
                genre, confidence = self.classify_audio(input_data)
            else:
                genre, confidence = self.classify_text(input_data)
            return {
                'genre': genre,
                'confidence': float(confidence),
                'input_type': input_type
            }
        except Exception as e:
            raise ValueError(f"Prediction failed: {str(e)}") from e