import sys
import logging
import io
import math
from typing import List

import numpy as np
import soundfile as sf

try:
    import torch
except ImportError:
    torch = None

from timed_objects import ASRToken

logger = logging.getLogger(__name__)


class ASRBase:
    sep = " "  # join transcribe words with this character (" " for whisper_timestamped,
               # "" for faster-whisper because it emits the spaces when needed)

    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
        self.logfile = logfile
        self.transcribe_kargs = {}
        if lan == "auto":
            self.original_language = None
        else:
            self.original_language = lan
        self.model = self.load_model(modelsize, cache_dir, model_dir)

    def with_offset(self, offset: float) -> ASRToken:
        # This method is kept for compatibility (typically you will use ASRToken.with_offset)
        return ASRToken(self.start + offset, self.end + offset, self.text)

    def __repr__(self):
        return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"

    def load_model(self, modelsize, cache_dir, model_dir):
        raise NotImplementedError("must be implemented in the child class")

    def transcribe(self, audio, init_prompt=""):
        raise NotImplementedError("must be implemented in the child class")

    def use_vad(self):
        raise NotImplementedError("must be implemented in the child class")


class WhisperTimestampedASR(ASRBase):
    """Uses whisper_timestamped as the backend."""

    sep = " "

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        print("Loading whisper_timestamped model")
        import whisper
        import whisper_timestamped
        from whisper_timestamped import transcribe_timestamped

        self.transcribe_timestamped = transcribe_timestamped
        if model_dir is not None:
            logger.debug("ignoring model_dir, not implemented")
        return whisper.load_model(modelsize, download_root=cache_dir)

    def transcribe(self, audio, init_prompt=""):
        result = self.transcribe_timestamped(
            self.model,
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            verbose=None,
            condition_on_previous_text=True,
            **self.transcribe_kargs,
        )
        return result

    def ts_words(self, r) -> List[ASRToken]:
        """Convert a whisper_timestamped result to a list of ASRToken objects."""
        tokens = []
        for segment in r["segments"]:
            for word in segment["words"]:
                token = ASRToken(word["start"], word["end"], word["text"])
                tokens.append(token)
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [segment["end"] for segment in res["segments"]]

    def use_vad(self):
        self.transcribe_kargs["vad"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"

    def detect_language(self, audio_file_path):
        """
        Detect the language of the audio using Whisper's language detection.

        Args:
            audio_file_path (str): Path to the audio file.

        Returns:
            tuple: (detected_language, confidence, probabilities)
                - detected_language (str): The detected language code
                - confidence (float): Confidence score for the detected language
                - probabilities (dict): Dictionary of language probabilities
        """
        import whisper

        try:
            # Load the audio and pad/trim it to the 30-second window Whisper expects
            audio = whisper.load_audio(audio_file_path)
            audio = whisper.pad_or_trim(audio)

            # Build a mel spectrogram with the number of mel bins the loaded model
            # expects (80 for most sizes, 128 for large-v3)
            mel = whisper.log_mel_spectrogram(audio, n_mels=self.model.dims.n_mels).to(self.model.device)

            # Detect language
            _, probs = self.model.detect_language(mel)
            detected_lang = max(probs, key=probs.get)
            confidence = probs[detected_lang]
            return detected_lang, confidence, probs
        except Exception as e:
            logger.error(f"Error in language detection: {e}")
            raise
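
# Usage sketch for this backend (assumptions: whisper_timestamped is installed
# and "sample.wav" is a 16 kHz mono file you provide; neither ships with this
# module):
#
#   asr = WhisperTimestampedASR(lan="en", modelsize="base")
#   audio, _ = sf.read("sample.wav", dtype="float32")
#   result = asr.transcribe(audio)
#   tokens = asr.ts_words(result)               # List[ASRToken] with word timestamps
#   last_segment_end = asr.segments_end_ts(result)[-1]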


class FasterWhisperASR(ASRBase):
    """Uses faster-whisper as the backend."""

    sep = ""

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        print("Loading faster-whisper model")
        from faster_whisper import WhisperModel

        if model_dir is not None:
            logger.debug(f"Loading whisper model from model_dir {model_dir}. "
                         f"modelsize and cache_dir parameters are not used.")
            model_size_or_path = model_dir
        elif modelsize is not None:
            model_size_or_path = modelsize
        else:
            raise ValueError("Either modelsize or model_dir must be set")

        device = "cuda" if torch and torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "float32"
        print(f"Loading whisper model {model_size_or_path} on {device} with compute type {compute_type}")

        model = WhisperModel(
            model_size_or_path,
            device=device,
            compute_type=compute_type,
            download_root=cache_dir,
        )
        return model

    def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
        segments, info = self.model.transcribe(
            audio,
            language=self.original_language,  # None ("auto") lets faster-whisper detect the language
            initial_prompt=init_prompt,
            beam_size=5,
            word_timestamps=True,
            condition_on_previous_text=True,
            **self.transcribe_kargs,
        )
        return list(segments)

    def ts_words(self, segments) -> List[ASRToken]:
        tokens = []
        for segment in segments:
            if segment.no_speech_prob > 0.9:
                continue
            for word in segment.words:
                token = ASRToken(word.start, word.end, word.word, probability=word.probability)
                tokens.append(token)
        return tokens

    def segments_end_ts(self, segments) -> List[float]:
        return [segment.end for segment in segments]

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"

    def detect_language(self, audio_file_path):
        """
        Detect the language of the audio using faster-whisper's language detection.

        Args:
            audio_file_path (str): Path to the audio file.

        Returns:
            tuple: (detected_language, confidence, probabilities)
                - detected_language (str): The detected language code
                - confidence (float): Confidence score for the detected language
                - probabilities (dict): Dictionary of language probabilities
        """
        from faster_whisper.audio import decode_audio

        try:
            audio = decode_audio(audio_file_path, sampling_rate=self.model.feature_extractor.sampling_rate)

            # Calculate the total number of 30-second segments (at least 1)
            audio_duration = len(audio) / self.model.feature_extractor.sampling_rate
            segments_num = max(1, int(audio_duration / 30))
            logger.info(f"Audio duration: {audio_duration:.2f}s, using {segments_num} segments for language detection")

            # Use faster-whisper's detect_language method over all available segments
            language, language_probability, all_language_probs = self.model.detect_language(
                audio=audio,
                vad_filter=False,  # Disable VAD for language detection
                language_detection_segments=segments_num,
                language_detection_threshold=0.5,  # Default threshold
            )

            # Convert the list of (language, probability) tuples to a dictionary
            # for a return format consistent with the other backends
            probs = {lang: prob for lang, prob in all_language_probs}
            return language, language_probability, probs
        except Exception as e:
            logger.error(f"Error in language detection: {e}")
            raise
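
# Usage sketch for this backend (assumptions: faster-whisper is installed and
# "sample.wav" is a 16 kHz mono file you provide):
#
#   asr = FasterWhisperASR(lan="en", modelsize="large-v3")
#   asr.use_vad()                                # optional: enable vad_filter
#   audio, _ = sf.read("sample.wav", dtype="float32")
#   segments = asr.transcribe(audio)
#   tokens = asr.ts_words(segments)              # List[ASRToken]
#   lang, confidence, probs = asr.detect_language("sample.wav")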


class MLXWhisper(ASRBase):
    """Uses MLX Whisper, optimized for Apple Silicon."""

    sep = ""

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        print("Loading mlx whisper model")
        from mlx_whisper.transcribe import ModelHolder, transcribe
        import mlx.core as mx

        if model_dir is not None:
            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
            model_size_or_path = model_dir
        elif modelsize is not None:
            model_size_or_path = self.translate_model_name(modelsize)
            logger.debug(f"Loading whisper model {modelsize}. The MLX repository {model_size_or_path} will be used.")
        else:
            raise ValueError("Either modelsize or model_dir must be set")

        self.model_size_or_path = model_size_or_path

        # Pre-load the model weights so the first transcribe call does not pay the loading cost
        dtype = mx.float16
        ModelHolder.get_model(model_size_or_path, dtype)
        return transcribe

    def translate_model_name(self, model_name):
        """Map a standard Whisper model name to the corresponding mlx-community repository."""
        model_mapping = {
            "tiny.en": "mlx-community/whisper-tiny.en-mlx",
            "tiny": "mlx-community/whisper-tiny-mlx",
            "base.en": "mlx-community/whisper-base.en-mlx",
            "base": "mlx-community/whisper-base-mlx",
            "small.en": "mlx-community/whisper-small.en-mlx",
            "small": "mlx-community/whisper-small-mlx",
            "medium.en": "mlx-community/whisper-medium.en-mlx",
            "medium": "mlx-community/whisper-medium-mlx",
            "large-v1": "mlx-community/whisper-large-v1-mlx",
            "large-v2": "mlx-community/whisper-large-v2-mlx",
            "large-v3": "mlx-community/whisper-large-v3-mlx",
            "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
            "large": "mlx-community/whisper-large-mlx",
        }
        mlx_model_path = model_mapping.get(model_name)
        if mlx_model_path:
            return mlx_model_path
        else:
            raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")

    def transcribe(self, audio, init_prompt=""):
        if self.transcribe_kargs:
            logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
        segments = self.model(
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            word_timestamps=True,
            condition_on_previous_text=True,
            path_or_hf_repo=self.model_size_or_path,
        )
        return segments.get("segments", [])

    def ts_words(self, segments) -> List[ASRToken]:
        tokens = []
        for segment in segments:
            if segment.get("no_speech_prob", 0) > 0.9:
                continue
            for word in segment.get("words", []):
                token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"])
                tokens.append(token)
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [s["end"] for s in res]

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"

    def detect_language(self, audio):
        raise NotImplementedError("MLX Whisper does not support language detection.")
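
# Usage sketch for this backend (assumptions: running on Apple Silicon with
# mlx-whisper installed, and "sample.wav" is a 16 kHz mono file you provide):
#
#   asr = MLXWhisper(lan="en", modelsize="large-v3-turbo")
#   audio, _ = sf.read("sample.wav", dtype="float32")
#   segments = asr.transcribe(audio)
#   tokens = asr.ts_words(segments)              # List[ASRToken]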


class OpenaiApiASR(ASRBase):
    """Uses OpenAI's Whisper API for transcription."""

    def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
        print("Loading openai api model")
        self.logfile = logfile
        self.modelname = "whisper-1"
        self.original_language = None if lan == "auto" else lan
        self.response_format = "verbose_json"
        self.temperature = temperature
        self.load_model()
        self.use_vad_opt = False
        self.task = "transcribe"

    def load_model(self, *args, **kwargs):
        from openai import OpenAI

        self.client = OpenAI()
        self.transcribed_seconds = 0

    def ts_words(self, segments) -> List[ASRToken]:
        """
        Convert OpenAI API response words into ASRToken objects, optionally
        skipping words that fall into no-speech segments.
        """
""" no_speech_segments = [] if self.use_vad_opt: for segment in segments.segments: if segment.no_speech_prob > 0.8: no_speech_segments.append((segment.start, segment.end)) tokens = [] for word in segments.words: start = word.start end = word.end if any(s[0] <= start <= s[1] for s in no_speech_segments): continue tokens.append(ASRToken(start, end, word.word)) return tokens def segments_end_ts(self, res) -> List[float]: return [s.end for s in res.words] def transcribe(self, audio_data, prompt=None, *args, **kwargs): buffer = io.BytesIO() buffer.name = "temp.wav" sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16") buffer.seek(0) self.transcribed_seconds += math.ceil(len(audio_data) / 16000) params = { "model": self.modelname, "file": buffer, "response_format": self.response_format, "temperature": self.temperature, "timestamp_granularities": ["word", "segment"], } if self.task != "translate" and self.original_language: params["language"] = self.original_language if prompt: params["prompt"] = prompt proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions transcript = proc.create(**params) logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds") return transcript def use_vad(self): self.use_vad_opt = True def set_translate_task(self): self.task = "translate" def detect_language(self, audio): raise NotImplementedError("MLX Whisper does not support language detection.")