File size: 5,713 Bytes
ed7741a
 
 
 
 
 
 
 
 
8599ceb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed7741a
3dc83fd
ed7741a
 
 
 
3dc83fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed7741a
 
 
 
 
 
46e5e67
 
 
 
 
ed7741a
 
 
3dc83fd
 
 
 
ed7741a
 
 
 
3dc83fd
 
 
46e5e67
 
3dc83fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed7741a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
import torchaudio
import librosa
import numpy as np
from transformers import pipeline
from typing import Union, Tuple, List

class MusicGenreClassifier:
    """Predict a music genre from an audio file or from lyrics text.

    Audio is scored with an AudioSet-finetuned AST audio-classification
    pipeline whose labels are mapped onto ``self.genres`` via
    ``self.label_mapping``; lyrics are scored with a zero-shot (NLI)
    classification pipeline against ``self.genres``.
    """

    # Hugging Face model ids used by the two pipelines.
    TEXT_MODEL = "facebook/bart-large-mnli"
    AUDIO_MODEL = "mit/ast-finetuned-audioset-10-10-0.4593"
    # Rate audio is resampled to on load; also reported to the audio pipeline.
    SAMPLE_RATE = 16000
    # File suffixes (case-insensitive) that make predict() treat input as audio.
    AUDIO_EXTENSIONS = ('.mp3', '.wav', '.ogg', '.flac')
    # Generic AudioSet label. It can fire for any music clip and would tend to
    # beat the genre-specific labels, so it is only used as a fallback.
    GENERIC_MUSIC_LABEL = "Music"
    # Result returned by classify_audio when no mapped label is predicted.
    DEFAULT_GENRE = "pop"
    DEFAULT_CONFIDENCE = 0.5

    def __init__(self):
        """Load both pipelines, preferring automatic device placement.

        If building the pipelines with ``device_map="auto"`` raises for any
        reason, a warning is printed and both pipelines are rebuilt on CPU.
        """
        try:
            self.text_classifier, self.audio_classifier = self._build_pipelines(
                device_map="auto"
            )
        except Exception as e:
            # NOTE(review): any failure lands here, not only GPU problems
            # (e.g. presumably a missing `accelerate` install for device_map).
            print(f"Warning: GPU initialization failed, falling back to CPU. Error: {str(e)}")
            self.text_classifier, self.audio_classifier = self._build_pipelines(
                device="cpu"
            )

        # Standard genres; used as zero-shot candidate labels for lyrics.
        self.genres = [
            "rock", "pop", "hip hop", "country", "jazz",
            "classical", "electronic", "blues", "reggae", "metal"
        ]

        # Mapping from audio-model output labels to our standard genres.
        self.label_mapping = {
            "Music": "pop",  # Generic fallback mapping (see GENERIC_MUSIC_LABEL)
            "Rock music": "rock",
            "Pop music": "pop",
            "Hip hop music": "hip hop",
            "Country": "country",
            "Jazz": "jazz",
            "Classical music": "classical",
            "Electronic music": "electronic",
            "Blues": "blues",
            "Reggae": "reggae",
            "Heavy metal": "metal"
        }

    def _build_pipelines(self, **device_kwargs):
        """Return ``(text_classifier, audio_classifier)`` built with *device_kwargs*."""
        text_classifier = pipeline(
            "zero-shot-classification",
            model=self.TEXT_MODEL,
            **device_kwargs
        )
        audio_classifier = pipeline(
            "audio-classification",
            model=self.AUDIO_MODEL,
            **device_kwargs
        )
        return text_classifier, audio_classifier

    def process_audio(self, audio_path: str) -> torch.Tensor:
        """Load *audio_path* as a mono float tensor of shape ``(1, num_samples)``.

        The audio is resampled to ``SAMPLE_RATE`` by librosa.

        Raises:
            ValueError: if the file cannot be loaded.
        """
        try:
            waveform, _ = librosa.load(audio_path, sr=self.SAMPLE_RATE)
            waveform = torch.from_numpy(waveform).float()
            # Add a leading channel axis so callers always see (1, N).
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            return waveform
        except Exception as e:
            raise ValueError(f"Error processing audio file: {str(e)}") from e

    def map_label_to_genre(self, label: str) -> str:
        """Map an audio-model label to a standard genre (``DEFAULT_GENRE`` if unknown)."""
        return self.label_mapping.get(label, self.DEFAULT_GENRE)

    def classify_audio(self, audio_path: str, *, top_k: int = 50) -> Tuple[str, float]:
        """Classify the genre of an audio file.

        Args:
            audio_path: Path to an audio file loadable by ``process_audio``.
            top_k: Number of highest-scoring model labels to inspect. Genre
                labels may rank below generic ones, so a small value can
                miss them entirely.

        Returns:
            ``(genre, score)``. Genre-specific labels take precedence over the
            generic "Music" label. If no mapped label appears among the
            inspected predictions, ``(DEFAULT_GENRE, DEFAULT_CONFIDENCE)`` is
            returned.

        Raises:
            ValueError: if loading or classification fails.
        """
        try:
            waveform = self.process_audio(audio_path)
            # The audio-classification pipeline requires single-channel 1-D
            # input; drop the channel axis (averaging if there are several).
            if waveform.dim() > 1:
                waveform = waveform.mean(dim=0)
            payload = {
                "raw": waveform.detach().cpu().numpy(),
                "sampling_rate": self.SAMPLE_RATE,
            }
            predictions = self.audio_classifier(payload, top_k=top_k)

            # A single input yields a flat list of {"label", "score"} dicts;
            # unwrap only when a nested (batched) list comes back.
            if predictions and isinstance(predictions[0], list):
                predictions = predictions[0]

            matched = [
                (p['label'], p['score'])
                for p in predictions
                if p['label'] in self.label_mapping
            ]
            specific = [m for m in matched if m[0] != self.GENERIC_MUSIC_LABEL]
            candidates = specific or matched

            if not candidates:
                return self.DEFAULT_GENRE, self.DEFAULT_CONFIDENCE

            label, score = max(candidates, key=lambda item: item[1])
            return self.map_label_to_genre(label), float(score)

        except Exception as e:
            raise ValueError(f"Audio classification failed: {str(e)}") from e

    def classify_text(self, lyrics: str) -> Tuple[str, float]:
        """Classify the genre of *lyrics* via zero-shot classification.

        Returns:
            ``(genre, score)`` for the highest-scoring entry of ``self.genres``.

        Raises:
            ValueError: if *lyrics* is empty/blank or classification fails.
        """
        if not isinstance(lyrics, str) or not lyrics.strip():
            raise ValueError("Text classification failed: lyrics must be a non-empty string")
        try:
            result = self.text_classifier(
                lyrics,
                candidate_labels=self.genres,
                hypothesis_template="This text contains {} music lyrics."
            )
            # Labels/scores come back sorted best-first.
            return result['labels'][0], float(result['scores'][0])
        except Exception as e:
            raise ValueError(f"Text classification failed: {str(e)}") from e

    def predict(self, input_data: str, input_type: Union[str, None] = None) -> dict:
        """
        Main prediction method that handles both audio and text inputs.

        Args:
            input_data: Path to audio file or lyrics text
            input_type: Optional, 'audio' or 'text'. If None, input is treated
                as audio when it ends with one of AUDIO_EXTENSIONS
                (case-insensitive), otherwise as text.

        Returns:
            dict with keys 'genre', 'confidence' (float) and 'input_type'.

        Raises:
            ValueError: for an unsupported input_type or if prediction fails.
        """
        if input_type is None:
            input_type = 'audio' if input_data.lower().endswith(self.AUDIO_EXTENSIONS) else 'text'
        elif input_type not in ('audio', 'text'):
            raise ValueError(
                f"Unsupported input_type {input_type!r}; expected 'audio' or 'text'"
            )

        try:
            if input_type == 'audio':
                genre, confidence = self.classify_audio(input_data)
            else:
                genre, confidence = self.classify_text(input_data)
        except Exception as e:
            raise ValueError(f"Prediction failed: {str(e)}") from e

        return {
            'genre': genre,
            'confidence': float(confidence),
            'input_type': input_type
        }