File size: 3,748 Bytes
ed7741a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46e5e67
ed7741a
 
46e5e67
ed7741a
 
 
 
 
 
 
 
 
 
 
 
46e5e67
 
 
 
 
ed7741a
 
 
 
 
 
 
46e5e67
ed7741a
46e5e67
 
ed7741a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torchaudio
import librosa
import numpy as np
from transformers import pipeline
from typing import Union, Tuple, List

class MusicGenreClassifier:
    """Predict a music genre from either an audio file path or a block of lyrics.

    Audio inputs are sent to a Hugging Face ``audio-classification`` pipeline.
    Text inputs are sent to a ``zero-shot-classification`` pipeline and scored
    against the labels in ``self.genres``.
    """

    # Sampling rate (Hz) audio is resampled to when loaded by ``process_audio``.
    SAMPLE_RATE = 16000

    # Extensions that make ``predict`` treat its input as an audio path when
    # ``input_type`` is not given (compared case-insensitively).
    AUDIO_EXTENSIONS = ('.mp3', '.wav', '.ogg', '.flac')

    # Default candidate labels for the zero-shot text classifier.
    GENRES = (
        "rock", "pop", "hip hop", "country", "jazz",
        "classical", "electronic", "blues", "reggae", "metal",
    )

    def __init__(self):
        """Create the text and audio classification pipelines (both eagerly)."""
        self.text_classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli"
        )

        # NOTE(review): confirm this model ID exists on the Hub and actually
        # classifies genre. Its output labels are NOT constrained to
        # ``self.genres``, which only applies to the text path.
        self.audio_classifier = pipeline(
            "audio-classification",
            model="superb/wav2vec2-base-superb-gc"
        )

        # Per-instance mutable copy so one instance can change its label set
        # without affecting others.
        self.genres = list(self.GENRES)

    def process_audio(self, audio_path: str, target_sr: Union[int, None] = None) -> torch.Tensor:
        """Load an audio file as a mono float tensor.

        Args:
            audio_path: Path to an audio file readable by ``librosa.load``.
            target_sr: Sampling rate (Hz) to resample to. ``None`` (default)
                uses ``self.SAMPLE_RATE``.

        Returns:
            A ``float32`` tensor of shape ``(1, num_samples)``.

        Raises:
            ValueError: If the file cannot be loaded or converted.
        """
        sr = self.SAMPLE_RATE if target_sr is None else target_sr
        try:
            # librosa.load downmixes to mono by default and resamples to ``sr``;
            # the returned rate equals ``sr`` so it is discarded.
            samples, _ = librosa.load(audio_path, sr=sr)
            waveform = torch.from_numpy(samples).float()
            if waveform.dim() == 1:
                waveform = waveform.unsqueeze(0)
            return waveform
        except Exception as e:
            raise ValueError(f"Error processing audio file: {str(e)}") from e

    @staticmethod
    def _top_prediction(predictions) -> Tuple[str, float]:
        """Reduce an audio-classification result to its best ``(label, score)``.

        Accepts a single ``{'label', 'score'}`` dict, a list of such dicts
        (single input), or a list of such lists (batched input; only the first
        item is used).

        Raises:
            ValueError: If there are no candidate predictions.
        """
        if isinstance(predictions, dict):
            candidates = [predictions]
        elif predictions and isinstance(predictions[0], list):
            candidates = predictions[0]
        else:
            candidates = predictions
        if not candidates:
            raise ValueError("Audio classifier returned no predictions")
        best = max(candidates, key=lambda p: p['score'])
        return best['label'], best['score']

    def classify_audio(self, audio_path: str) -> Tuple[str, float]:
        """Classify genre from an audio file.

        Returns:
            ``(label, score)`` of the highest-scoring prediction.

        Raises:
            ValueError: If loading or classification fails.
        """
        try:
            waveform = self.process_audio(audio_path)
            # The pipeline wants 1-D audio. Collapse the channel axis of the
            # (1, N) tensor, and pass the sampling rate explicitly so the
            # pipeline can resample if its model expects a different rate.
            if waveform.dim() > 1:
                waveform = waveform.mean(dim=0)
            audio = {
                "raw": waveform.cpu().numpy(),
                "sampling_rate": self.SAMPLE_RATE,
            }
            predictions = self.audio_classifier(audio, top_k=1)
            return self._top_prediction(predictions)
        except Exception as e:
            raise ValueError(f"Audio classification failed: {str(e)}") from e

    def classify_text(self, lyrics: str) -> Tuple[str, float]:
        """Classify genre from lyrics using zero-shot classification.

        Returns:
            ``(label, score)`` for the first label returned by the pipeline.

        Raises:
            ValueError: If classification fails.
        """
        # Each candidate genre is substituted into ``{}`` to form an NLI hypothesis.
        hypothesis_template = "This text contains {} music lyrics."
        try:
            result = self.text_classifier(
                lyrics,
                candidate_labels=self.genres,
                hypothesis_template=hypothesis_template
            )
            # Relies on the zero-shot pipeline ordering labels by descending
            # score (per its documentation), so index 0 is the best match.
            return result['labels'][0], result['scores'][0]
        except Exception as e:
            raise ValueError(f"Text classification failed: {str(e)}") from e

    def predict(self, input_data: str, input_type: Union[str, None] = None) -> dict:
        """
        Main prediction method that handles both audio and text inputs.

        Args:
            input_data: Path to an audio file, or lyrics text. Must be a
                non-empty string.
            input_type: Optional, 'audio' or 'text'. If None, it is inferred
                from the file extension (see ``AUDIO_EXTENSIONS``); anything
                without an audio extension is treated as text.

        Returns:
            dict with keys 'genre', 'confidence' (float) and 'input_type'.

        Raises:
            ValueError: On empty/non-string input, an unknown ``input_type``,
                or any failure during classification.
        """
        if not isinstance(input_data, str) or not input_data.strip():
            raise ValueError("input_data must be a non-empty string")

        if input_type is None:
            input_type = 'audio' if input_data.lower().endswith(self.AUDIO_EXTENSIONS) else 'text'
        if input_type not in ('audio', 'text'):
            raise ValueError(f"input_type must be 'audio' or 'text', got {input_type!r}")

        try:
            if input_type == 'audio':
                genre, confidence = self.classify_audio(input_data)
            else:
                genre, confidence = self.classify_text(input_data)

            return {
                'genre': genre,
                'confidence': float(confidence),
                'input_type': input_type
            }
        except Exception as e:
            raise ValueError(f"Prediction failed: {str(e)}") from e