File size: 4,908 Bytes
ca494e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
Alias module to redirect whisper imports to whisperx.
This allows OuteTTS to use whisperx instead of the standard whisper package.
"""

import sys
import importlib.util

def _resolve_device(device="auto"):
    """Normalize a whisper-style ``device`` argument to a device string.

    ``"auto"`` or ``None`` picks ``"cuda"`` when PyTorch is importable and
    reports an available GPU, otherwise ``"cpu"``. A value whose string form
    is ``"cuda"`` or ``"cuda:<index>"`` (e.g. ``str(torch.device("cuda", 1))``)
    is kept as-is; anything else maps to ``"cpu"``.

    Args:
        device: Device specifier as passed to ``whisper.load_model``.

    Returns:
        ``"cpu"``, ``"cuda"`` or ``"cuda:<index>"``.
    """
    if device is None or str(device) == "auto":
        try:
            import torch
        except ImportError:
            return "cpu"
        return "cuda" if torch.cuda.is_available() else "cpu"
    text = str(device)
    if text == "cuda" or text.startswith("cuda:"):
        return text
    return "cpu"


def _split_device(device):
    """Split ``"cuda:1"`` into ``("cuda", 1)``; a device without an index gets 0."""
    device_type, _, index = device.partition(":")
    return device_type, (int(index) if index else 0)


def _fill_word_times(segment):
    """Ensure every word dict in ``segment["words"]`` has ``start`` and ``end``.

    Mutates *segment* in place. A missing (or ``None``) ``start`` falls back to
    the previous word's ``end`` (the segment's ``start``, or 0.0, for the first
    word); a missing ``end`` falls back to the word's own ``start``.

    NOTE(review): whisperx alignment can leave unalignable words (e.g. digits)
    without timestamps, while whisper always provides them -- confirm against
    the installed whisperx version.
    """
    previous_end = segment.get("start") or 0.0
    for word in segment.get("words", []):
        if word.get("start") is None:
            word["start"] = previous_end
        if word.get("end") is None:
            word["end"] = word["start"]
        previous_end = word["end"]


def setup_whisper_alias():
    """Register a whisperx-backed stand-in for the ``whisper`` module.

    After a successful call, ``import whisper`` yields an object whose
    ``load_model(name, **kwargs)`` builds a whisperx model and wraps it so that
    ``model.transcribe(audio, **kwargs)`` returns a whisper-style dict with
    ``text``, ``segments`` (each with ``words`` when ``word_timestamps=True``)
    and, when known, ``language``.

    If whisperx cannot be found or imported, a warning is printed and
    ``sys.modules`` is left untouched, so a real ``whisper`` install (if any)
    is used instead. Any other setup error is reported and swallowed so that
    importing this module never fails.
    """
    try:
        # Bail out early (without touching sys.modules) if whisperx is absent.
        if importlib.util.find_spec("whisperx") is None:
            print("Warning: whisperx not found, falling back to regular whisper")
            return

        import whisperx

        class WhisperXModelWrapper:
            """Wrapper to make whisperx compatible with whisper interface."""

            def __init__(self, model, device):
                self.model = model
                # Full device string (e.g. "cuda:1"), used for alignment.
                self.device = device
                # (align_model, metadata) per language code; loading one is
                # expensive, so it is reused across transcribe() calls.
                self._align_models = {}

            def _get_align_model(self, language_code):
                """Return a cached ``(model, metadata)`` alignment pair."""
                if language_code not in self._align_models:
                    self._align_models[language_code] = whisperx.load_align_model(
                        language_code=language_code,
                        device=self.device,
                    )
                return self._align_models[language_code]

            def transcribe(self, audio, **kwargs):
                """Transcribe audio with a whisper-compatible interface.

                Args:
                    audio: File path (loaded via ``whisperx.load_audio``) or
                        already-loaded audio data.
                    **kwargs: Honoured keys: ``word_timestamps`` (align and add
                        per-segment ``words``), ``batch_size`` (default 16),
                        ``language`` and ``task`` (forwarded when not None).
                        Other whisper options are ignored.

                Returns:
                    A dict with at least ``segments`` and ``text``.
                """
                word_timestamps = kwargs.get("word_timestamps", False)

                if isinstance(audio, str):
                    audio_data = whisperx.load_audio(audio)
                else:
                    audio_data = audio

                options = {"batch_size": kwargs.get("batch_size", 16)}
                # Forward whisper's language/task hints when given; otherwise
                # whisperx auto-detects as before.
                for key in ("language", "task"):
                    if kwargs.get(key) is not None:
                        options[key] = kwargs[key]
                result = self.model.transcribe(audio_data, **options)

                if word_timestamps and result.get("segments"):
                    try:
                        model_a, metadata = self._get_align_model(
                            result.get("language", "en")
                        )
                        aligned = whisperx.align(
                            result["segments"],
                            model_a,
                            metadata,
                            audio_data,
                            self.device,
                            return_char_alignments=False,
                        )
                    except Exception as e:
                        # Best effort: keep the unaligned transcription.
                        print(f"Warning: Could not perform alignment: {e}")
                    else:
                        # The aligned dict may not carry "language" over; keep it
                        # so the result matches whisper's shape.
                        if "language" in result:
                            aligned.setdefault("language", result["language"])
                        result = aligned

                segments = result.setdefault("segments", [])

                if "text" not in result:
                    # Strip each piece before joining so segment texts that may
                    # already start with a space do not produce double spaces.
                    pieces = (seg.get("text", "").strip() for seg in segments)
                    result["text"] = " ".join(piece for piece in pieces if piece)

                if word_timestamps:
                    for segment in segments:
                        segment.setdefault("words", [])
                        _fill_word_times(segment)

                return result

        class WhisperAlias:
            """Object installed as ``sys.modules['whisper']``."""

            def __init__(self):
                self.model = getattr(whisperx, "WhisperModel", None)
                self.load_model = self._load_model

            def _load_model(self, name, **kwargs):
                """Load a whisperx model behind a whisper-compatible wrapper.

                Args:
                    name: Model name/size passed to ``whisperx.load_model``.
                    **kwargs: ``device`` ("auto"/None, "cpu", "cuda", "cuda:N"
                        or a torch.device; default "auto"), optional
                        ``compute_type`` (default float16 on CUDA, int8 on CPU)
                        and optional ``download_root``.

                Returns:
                    A ``WhisperXModelWrapper``.
                """
                device = _resolve_device(kwargs.get("device", "auto"))
                device_type, device_index = _split_device(device)
                compute_type = kwargs.get("compute_type") or (
                    "float16" if device_type == "cuda" else "int8"
                )

                extra = {}
                if kwargs.get("download_root") is not None:
                    extra["download_root"] = kwargs["download_root"]

                model = whisperx.load_model(
                    name,
                    device=device_type,
                    device_index=device_index,
                    compute_type=compute_type,
                    **extra,
                )
                return WhisperXModelWrapper(model, device)

        # Replace the module entry so later `import whisper` statements get the alias.
        sys.modules['whisper'] = WhisperAlias()

        print("✅ Successfully aliased whisper to whisperx")

    except ImportError as e:
        print(f"Warning: Could not setup whisper alias: {e}")
        print("Falling back to regular whisper (if available)")
    except Exception as e:
        print(f"Warning: Error setting up whisper alias: {e}")

# Auto-setup when module is imported: importing this module is enough to make a
# later `import whisper` resolve to the whisperx-backed alias (when whisperx is
# available). Because the alias is installed by assigning sys.modules['whisper'],
# modules that already imported `whisper` before this point keep their existing
# reference -- import this module first.
setup_whisper_alias()