# Commit ca494e8 (unverified) by hans00: "Alias whisper to whisperx"
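"""
Alias module that redirects `whisper` imports to whisperx.

Importing this module registers a whisper-compatible stand-in under the name
`whisper` in `sys.modules`, so OuteTTS (or any other code that does
`import whisper`) uses whisperx instead of the standard openai-whisper package.
"""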
import sys
import importlib.util
def setup_whisper_alias():
    """Register a whisperx-backed stand-in for the ``whisper`` module.

    On success, ``import whisper`` anywhere in the process returns a module
    object with:

    * ``load_model(name, device=None, download_root=None, **kwargs)``, which
      loads a whisperx model and wraps it, and
    * ``model``, which is ``whisperx.WhisperModel`` if whisperx exposes it,
      else ``None``.

    The wrapped model's ``transcribe(audio, **kwargs)`` returns a dict with
    ``text``, ``segments`` and ``language``. When ``word_timestamps=True``,
    every segment also carries a ``words`` list.

    Ordinary errors are not raised. If whisperx is missing or setup fails, a
    warning is printed and ``sys.modules`` is not modified, so a real
    ``whisper`` install (if any) is used instead.
    """
    import os

    try:
        # Check if whisperx is available before importing it.
        if importlib.util.find_spec("whisperx") is None:
            print("Warning: whisperx not found, falling back to regular whisper")
            return
        import whisperx

        def _resolve_device(requested):
            """Map a whisper-style device request to ``(kind, index)``.

            ``None`` / ``"auto"`` mirror openai-whisper: use CUDA when torch
            reports it available, otherwise the CPU. ``"cuda"`` / ``"cuda:N"``
            (as ``str`` or ``torch.device``) select GPU ``N``; any other value
            falls back to the CPU, matching the original "cuda or cpu" choice.
            """
            if requested is None or str(requested) == "auto":
                try:
                    import torch
                except ImportError:
                    return "cpu", 0
                return ("cuda", 0) if torch.cuda.is_available() else ("cpu", 0)
            kind, _, index = str(requested).partition(":")
            if kind != "cuda":
                return "cpu", 0
            return "cuda", int(index) if index.isdigit() else 0

        class WhisperXModelWrapper:
            """Wrapper to make a whisperx model look like a whisper model."""

            def __init__(self, model, device):
                self.model = model
                # Device string handed to whisperx alignment ("cpu", "cuda", "cuda:N").
                self.device = device
                # Alignment models are loaded lazily and reused per language code,
                # instead of being reloaded on every transcribe() call.
                self._align_models = {}

            def _get_align_model(self, language):
                """Return the cached ``(align_model, metadata)`` pair for *language*."""
                if language not in self._align_models:
                    self._align_models[language] = whisperx.load_align_model(
                        language_code=language,
                        device=self.device,
                    )
                return self._align_models[language]

            def transcribe(self, audio, **kwargs):
                """Transcribe *audio* and return a whisper-shaped result dict.

                *audio* may be a file path (``str`` or ``os.PathLike``), which is
                loaded with ``whisperx.load_audio``. Anything else is passed to
                whisperx unchanged.

                Recognised kwargs:

                * ``batch_size`` (default 16), forwarded to whisperx.
                * ``language`` and ``task``, forwarded to whisperx when not None.
                * ``word_timestamps``, which runs whisperx alignment when true.

                Other kwargs are ignored.
                """
                word_timestamps = kwargs.get("word_timestamps", False)

                # Load audio if it's a file path (str or pathlib-style object).
                if isinstance(audio, (str, os.PathLike)):
                    audio_data = whisperx.load_audio(os.fspath(audio))
                else:
                    audio_data = audio

                options = {"batch_size": kwargs.get("batch_size", 16)}
                for key in ("language", "task"):
                    if kwargs.get(key) is not None:
                        options[key] = kwargs[key]
                result = self.model.transcribe(audio_data, **options)

                language = result.get("language") or kwargs.get("language") or "en"

                # If word timestamps are requested, perform alignment (best effort).
                if word_timestamps and result.get("segments"):
                    try:
                        align_model, metadata = self._get_align_model(language)
                        aligned = whisperx.align(
                            result["segments"],
                            align_model,
                            metadata,
                            audio_data,
                            self.device,
                            return_char_alignments=False,
                        )
                    except Exception as e:
                        print(f"Warning: Could not perform alignment: {e}")
                        # Continue with the unaligned transcription.
                    else:
                        # Merge rather than replace, so keys from the transcription
                        # (notably "language") survive if the aligned dict lacks them.
                        result = {**result, **aligned}

                # Ensure the result format is compatible with whisper's.
                result.setdefault("segments", [])
                result.setdefault("language", language)
                if "text" not in result:
                    # Trim each segment before joining so leading/trailing spaces in
                    # segment texts do not produce doubled whitespace.
                    pieces = (seg.get("text", "").strip() for seg in result["segments"])
                    result["text"] = " ".join(piece for piece in pieces if piece)
                if word_timestamps:
                    # Requested but missing (e.g. alignment failed): expose an empty list.
                    for segment in result["segments"]:
                        segment.setdefault("words", [])
                return result

        def load_model(name, device=None, download_root=None, **kwargs):
            """Load whisperx model *name* behind a whisper-compatible wrapper.

            ``compute_type`` may be given in kwargs. It defaults to "float16" on
            CUDA and "int8" on the CPU. ``download_root`` is forwarded to whisperx
            when provided. Other kwargs are ignored.
            """
            device_kind, device_index = _resolve_device(device)
            compute_type = kwargs.get("compute_type") or (
                "float16" if device_kind == "cuda" else "int8"
            )
            options = {"device": device_kind, "compute_type": compute_type}
            if device_index:
                options["device_index"] = device_index
            if download_root is not None:
                options["download_root"] = download_root
            model = whisperx.load_model(name, **options)
            align_device = f"cuda:{device_index}" if device_index else device_kind
            return WhisperXModelWrapper(model, align_device)

        # Build a real module object with a __spec__, so that later probes such as
        # importlib.util.find_spec("whisper") work instead of raising ValueError.
        alias = importlib.util.module_from_spec(
            importlib.util.spec_from_loader("whisper", loader=None)
        )
        alias.__doc__ = "whisperx-backed stand-in for the openai-whisper package."
        alias.load_model = load_model
        alias.model = getattr(whisperx, "WhisperModel", None)

        # Add to sys.modules so 'import whisper' uses our alias.
        sys.modules["whisper"] = alias
        print("✅ Successfully aliased whisper to whisperx")
    except ImportError as e:
        print(f"Warning: Could not setup whisper alias: {e}")
        print("Falling back to regular whisper (if available)")
    except Exception as e:
        print(f"Warning: Error setting up whisper alias: {e}")
# Install the alias as an import-time side effect: importing this module is
# enough to redirect any later `import whisper` in the same process.
setup_whisper_alias()