File size: 4,522 Bytes
b5df735
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""
Speaker diarization implementation using pyannote.audio
"""

import os
import torch
from typing import Optional, List, Dict, Any

from ..interfaces.speaker_detector import ISpeakerDetector
from ..utils.config import AudioProcessingConfig
from ..utils.errors import SpeakerDiarizationError, ModelLoadError


class PyannoteSpeikerDetector(ISpeakerDetector):
    """Speaker diarization backed by a pyannote.audio pipeline.

    The pipeline is loaded lazily on the first ``detect_speakers`` call and
    cached on the instance. A Hugging Face token, read from the environment
    variable named by ``config.hf_token_env_var``, is required: without it
    ``is_available`` returns False and ``detect_speakers`` raises.

    NOTE: the spelling of the class name ("Speiker") is kept as-is because it
    is the public name importers use.
    """

    def __init__(self, config: Optional[AudioProcessingConfig] = None):
        """Store the configuration, pick a device and read the auth token.

        Args:
            config: Audio processing configuration; a default
                ``AudioProcessingConfig()`` is created when omitted.
        """
        self.config = config or AudioProcessingConfig()
        self.device = self._setup_device()
        # Loaded on first use by detect_speakers().
        self.pipeline = None
        self.auth_token = os.environ.get(self.config.hf_token_env_var)

        if not self.auth_token:
            print("⚠️ No Hugging Face token found. Speaker diarization will be disabled.")

    def _setup_device(self) -> torch.device:
        """Return the CUDA device when available, otherwise the CPU device."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    async def detect_speakers(
        self,
        audio_file_path: str,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 10
    ) -> Dict[str, Any]:
        """Run diarization on an audio file and summarize the speaker turns.

        Args:
            audio_file_path: Path of the audio file handed to the pipeline.
            num_speakers: Exact number of speakers, when known. Takes
                precedence over ``min_speakers`` / ``max_speakers``.
            min_speakers: Lower bound on the number of speakers; used only
                when ``num_speakers`` is None.
            max_speakers: Upper bound on the number of speakers; used only
                when ``num_speakers`` is None.

        Returns:
            A dict with keys ``speaker_count``, ``speakers`` (per-speaker
            ``id``, ``total_time`` and ``segments``), ``segments`` (every turn
            with ``start``, ``end`` and ``speaker``) and ``audio_file``.

        Raises:
            SpeakerDiarizationError: If no auth token is configured, or if
                loading the pipeline or running diarization fails (the
                original exception is attached as ``__cause__``).
        """
        if not self.auth_token:
            raise SpeakerDiarizationError(
                "Speaker diarization requires Hugging Face token",
                audio_file=audio_file_path
            )

        # Forward the caller's speaker-count constraints to the pipeline; an
        # exact count wins over the min/max bounds.
        if num_speakers is not None:
            speaker_hints = {"num_speakers": num_speakers}
        else:
            speaker_hints = {
                "min_speakers": min_speakers,
                "max_speakers": max_speakers,
            }

        try:
            # Load pipeline if not already loaded
            if self.pipeline is None:
                self.pipeline = self._load_pipeline()

            # NOTE(review): this call is synchronous, so the event loop is
            # blocked while diarization runs — consider an executor if this
            # coroutine shares a loop with other work.
            diarization = self.pipeline(audio_file_path, **speaker_hints)

            # Convert to our format
            speakers: Dict[str, Dict[str, Any]] = {}
            segments: List[Dict[str, Any]] = []

            for turn, _, speaker in diarization.itertracks(yield_label=True):
                # Keep only the trailing label component, zero-padded to two
                # characters; str() guards against non-string labels.
                speaker_id = f"SPEAKER_{str(speaker).split('_')[-1].zfill(2)}"
                segments.append({
                    "start": turn.start,
                    "end": turn.end,
                    "speaker": speaker_id
                })

                if speaker_id not in speakers:
                    speakers[speaker_id] = {
                        "id": speaker_id,
                        "total_time": 0.0,
                        "segments": []
                    }

                speakers[speaker_id]["total_time"] += turn.end - turn.start
                speakers[speaker_id]["segments"].append({
                    "start": turn.start,
                    "end": turn.end
                })

            return {
                "speaker_count": len(speakers),
                "speakers": speakers,
                "segments": segments,
                "audio_file": audio_file_path
            }

        except Exception as e:
            # Single error type for callers; chain the cause for debugging.
            raise SpeakerDiarizationError(
                f"Speaker detection failed: {str(e)}",
                audio_file=audio_file_path
            ) from e

    def _load_pipeline(self):
        """Load the configured pyannote pipeline and move it to ``self.device``.

        Returns:
            The pipeline returned by ``Pipeline.from_pretrained`` for
            ``config.speaker_diarization_model``.

        Raises:
            ModelLoadError: If importing pyannote.audio, downloading or
                initializing the pipeline fails (the original exception is
                attached as ``__cause__``).
        """
        try:
            # Suppress warnings
            import warnings
            warnings.filterwarnings("ignore", category=UserWarning, module="pyannote")
            warnings.filterwarnings("ignore", category=UserWarning, module="pytorch_lightning")
            warnings.filterwarnings("ignore", category=FutureWarning, module="pytorch_lightning")

            # Imported lazily so the module can be imported without pyannote.
            from pyannote.audio import Pipeline

            print("📥 Loading speaker diarization pipeline...")
            pipeline = Pipeline.from_pretrained(
                self.config.speaker_diarization_model,
                use_auth_token=self.auth_token
            )
            pipeline.to(self.device)

            return pipeline

        except Exception as e:
            raise ModelLoadError(
                f"Failed to load speaker diarization pipeline: {str(e)}",
                model_name=self.config.speaker_diarization_model
            ) from e

    def get_supported_models(self) -> List[str]:
        """Return a one-element list holding the configured diarization model name."""
        return [self.config.speaker_diarization_model]

    def is_available(self) -> bool:
        """Return True when a non-empty Hugging Face token was found.

        Uses the same truthiness test as ``__init__`` and ``detect_speakers``
        so an empty-string token is reported as unavailable rather than
        failing later at detection time.
        """
        return bool(self.auth_token)