Spaces:
Running
on
Zero
Running
on
Zero
"""Alias module that redirects ``import whisper`` to whisperx.

Importing this module immediately calls ``setup_whisper_alias()``, which
registers a whisper-like object under ``sys.modules["whisper"]`` so that
OuteTTS (which imports ``whisper``) transcribes with whisperx instead of the
standard openai-whisper package. If whisperx is missing or fails to import,
nothing is registered and a regular ``whisper`` install (if any) is used.
"""
import importlib.util
import os
import sys
import types
def _resolve_device(device=None):
    """Translate a whisper-style ``device`` argument into ``"cuda"`` or ``"cpu"``.

    openai-whisper's ``load_model`` treats ``None`` as "use CUDA if available"
    and also accepts indexed strings such as ``"cuda:0"`` or ``torch.device``
    objects. The whisperx loader is called here with a plain ``"cuda"`` or
    ``"cpu"``.

    Args:
        device: ``None`` or ``"auto"`` auto-detects via torch (CPU if torch
            is not importable). Anything whose ``str()`` starts with
            ``"cuda"`` selects CUDA; everything else selects CPU.

    Returns:
        ``"cuda"`` or ``"cpu"``.
    """
    if device is None or str(device) == "auto":
        try:
            import torch
        except ImportError:
            return "cpu"
        return "cuda" if torch.cuda.is_available() else "cpu"
    # str() normalises torch.device objects; the prefix test accepts "cuda:N".
    return "cuda" if str(device).startswith("cuda") else "cpu"


def _join_segment_text(segments):
    """Return the text of *segments* joined by single spaces.

    Segment texts may carry their own leading/trailing whitespace, so each one
    is stripped first and empty ones are skipped to avoid doubled spaces.
    """
    pieces = ((segment.get("text") or "").strip() for segment in segments)
    return " ".join(piece for piece in pieces if piece)


def _complete_word_timings(segment):
    """Ensure every word in *segment* has whisper-style timing keys.

    whisperx omits ``start``/``end`` for words it cannot align (e.g. numbers)
    and reports confidence as ``score`` rather than whisper's
    ``probability``. A missing ``start`` is taken from the previous word's end
    (or the segment start). A missing ``end`` is taken from the next known
    word start (or the segment end). A ``words`` list is created if absent.
    """
    words = segment.setdefault("words", [])
    segment_start = segment.get("start", 0.0)
    segment_end = segment.get("end", segment_start)
    cursor = segment_start
    for index, word in enumerate(words):
        if "probability" not in word and "score" in word:
            word["probability"] = word["score"]
        if "start" not in word:
            word["start"] = cursor
        if "end" not in word:
            upcoming = next(
                (later["start"] for later in words[index + 1:] if "start" in later),
                segment_end,
            )
            word["end"] = max(word["start"], upcoming)
        cursor = word["end"]


class _WhisperXModelWrapper:
    """Expose a whisperx pipeline through whisper's ``model.transcribe`` API."""

    def __init__(self, model, device):
        self.model = model
        self.device = device
        # (model, metadata) pairs keyed by language code. Loading an
        # alignment model is expensive, so reuse it across transcribe() calls.
        self._align_models = {}

    def _align_model(self, language):
        """Return ``(model, metadata)`` for *language*, loading it at most once."""
        import whisperx

        if language not in self._align_models:
            self._align_models[language] = whisperx.load_align_model(
                language_code=language, device=self.device
            )
        return self._align_models[language]

    def transcribe(self, audio, **kwargs):
        """Transcribe *audio* and return a whisper-shaped result dict.

        Args:
            audio: A file path (``str`` or path-like), which is loaded with
                ``whisperx.load_audio``, or already-loaded audio that is passed
                to the whisperx pipeline unchanged.
            **kwargs: ``word_timestamps`` (default False) triggers whisperx
                alignment. ``batch_size`` (default 16) is forwarded to the
                pipeline. Other whisper options are ignored.

        Returns:
            The whisperx result, guaranteed to contain ``segments`` and
            ``text``. When word timestamps were requested, every segment also
            has a ``words`` list whose entries all carry ``start``/``end``.
            If alignment fails, a warning is printed and the unaligned result
            is returned.
        """
        import whisperx

        word_timestamps = kwargs.get("word_timestamps", False)
        if isinstance(audio, (str, os.PathLike)):
            audio = whisperx.load_audio(os.fspath(audio))

        result = self.model.transcribe(audio, batch_size=kwargs.get("batch_size", 16))

        if word_timestamps and result.get("segments"):
            try:
                align_model, metadata = self._align_model(result.get("language", "en"))
                aligned = whisperx.align(
                    result["segments"],
                    align_model,
                    metadata,
                    audio,
                    self.device,
                    return_char_alignments=False,
                )
            except Exception as e:
                # Best effort: keep the unaligned transcription.
                print(f"Warning: Could not perform alignment: {e}")
            else:
                # align() does not return the language; carry the detected one
                # over so callers still see it as they would with whisper.
                if "language" in result:
                    aligned.setdefault("language", result["language"])
                result = aligned

        segments = result.setdefault("segments", [])
        if "text" not in result:
            result["text"] = _join_segment_text(segments)
        if word_timestamps:
            for segment in segments:
                _complete_word_timings(segment)
        return result


def _load_whisperx_model(name, device=None, download_root=None, in_memory=False, **kwargs):
    """Whisper-compatible ``load_model`` returning a wrapped whisperx pipeline.

    Args:
        name: Model name, passed straight to ``whisperx.load_model``.
        device: As in whisper: ``None``/``"auto"`` auto-detects. Otherwise it
            is ``"cpu"``, ``"cuda"``, ``"cuda:N"`` or a ``torch.device``.
            May be given positionally.
        download_root: Forwarded to whisperx when not ``None``.
        in_memory: Accepted for signature compatibility with whisper; ignored.
        **kwargs: ``compute_type`` overrides the default precision
            (``float16`` on CUDA, ``int8`` on CPU). Other keys are ignored.

    Returns:
        A ``_WhisperXModelWrapper`` around the loaded pipeline.
    """
    import whisperx

    auto_detected = device is None or str(device) == "auto"
    resolved = _resolve_device(device)
    compute_override = kwargs.get("compute_type")
    extra = {"download_root": download_root} if download_root is not None else {}

    def load(target):
        compute_type = compute_override or ("float16" if target == "cuda" else "int8")
        return whisperx.load_model(name, device=target, compute_type=compute_type, **extra)

    try:
        model = load(resolved)
    except Exception as exc:
        # torch may report CUDA that whisperx's backend cannot actually use.
        # Only an *auto-detected* CUDA choice falls back to CPU (the previous
        # default); an explicit request re-raises.
        if not (auto_detected and resolved == "cuda"):
            raise
        print(f"Warning: whisperx could not load on CUDA ({exc}); falling back to CPU")
        resolved = "cpu"
        model = load(resolved)
    return _WhisperXModelWrapper(model, resolved)


def setup_whisper_alias():
    """Register a whisperx-backed shim as ``sys.modules["whisper"]``.

    After this succeeds, ``import whisper`` returns a module whose
    ``load_model`` builds a whisperx pipeline with whisper's ``transcribe``
    interface. Its ``model`` attribute is ``whisperx.WhisperModel`` if that
    exists, else ``None``. If whisperx is not installed, fails to import, or
    any other error occurs, a warning is printed and ``sys.modules`` is left
    untouched.
    """
    try:
        if importlib.util.find_spec("whisperx") is None:
            print("Warning: whisperx not found, falling back to regular whisper")
            return
        import whisperx

        alias = types.ModuleType("whisper", "whisper-compatible shim backed by whisperx.")
        alias.model = getattr(whisperx, "WhisperModel", None)
        alias.load_model = _load_whisperx_model
        sys.modules["whisper"] = alias
        print("✅ Successfully aliased whisper to whisperx")
    except ImportError as e:
        print(f"Warning: Could not setup whisper alias: {e}")
        print("Falling back to regular whisper (if available)")
    except Exception as e:
        print(f"Warning: Error setting up whisper alias: {e}")
# Install the alias as an import-time side effect, so any `import whisper`
# executed after this module has been imported resolves to the shim.
setup_whisper_alias()