import sys
import logging
import io
import soundfile as sf
import math

try:
    import torch
except ImportError:
    torch = None

from typing import List
import numpy as np
from timed_objects import ASRToken

logger = logging.getLogger(__name__)
class ASRBase:
    sep = " "  # join transcribed words with this character (" " for whisper_timestamped,
               # "" for faster-whisper because it emits the spaces when needed)

    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
        self.logfile = logfile
        self.transcribe_kargs = {}
        if lan == "auto":
            self.original_language = None
        else:
            self.original_language = lan
        self.model = self.load_model(modelsize, cache_dir, model_dir)

    def with_offset(self, offset: float) -> ASRToken:
        # This method is kept for compatibility (typically you will use ASRToken.with_offset).
        return ASRToken(self.start + offset, self.end + offset, self.text)

    def __repr__(self):
        return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"

    def load_model(self, modelsize, cache_dir, model_dir):
        raise NotImplementedError("must be implemented in the child class")

    def transcribe(self, audio, init_prompt=""):
        raise NotImplementedError("must be implemented in the child class")

    def use_vad(self):
        raise NotImplementedError("must be implemented in the child class")
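
# Illustrative sketch (commented out, not part of the original module): a concrete
# backend only has to implement load_model / transcribe / ts_words / segments_end_ts
# / use_vad. The class name DummyASR and its fake word tuples are hypothetical and
# exist purely to show the contract that the real backends below fulfil.
#
# class DummyASR(ASRBase):
#     sep = " "
#
#     def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
#         return None  # no real model; transcribe() fabricates its output
#
#     def transcribe(self, audio, init_prompt=""):
#         # pretend the whole 16 kHz buffer is a single word starting at t=0
#         duration = len(audio) / 16000
#         return [(0.0, duration, "hello")]
#
#     def ts_words(self, result) -> List[ASRToken]:
#         return [ASRToken(start, end, text) for start, end, text in result]
#
#     def segments_end_ts(self, result) -> List[float]:
#         return [end for _, end, _ in result]
#
#     def use_vad(self):
#         pass  # nothing to configure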
class WhisperTimestampedASR(ASRBase):
    """Uses whisper_timestamped as the backend."""

    sep = " "

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        print("Loading whisper_timestamped model")
        import whisper
        import whisper_timestamped
        from whisper_timestamped import transcribe_timestamped

        self.transcribe_timestamped = transcribe_timestamped
        if model_dir is not None:
            logger.debug("ignoring model_dir, not implemented")
        return whisper.load_model(modelsize, download_root=cache_dir)

    def transcribe(self, audio, init_prompt=""):
        result = self.transcribe_timestamped(
            self.model,
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            verbose=None,
            condition_on_previous_text=True,
            **self.transcribe_kargs,
        )
        return result
    def ts_words(self, r) -> List[ASRToken]:
        """Convert the whisper_timestamped result to a list of ASRToken objects."""
        tokens = []
        for segment in r["segments"]:
            for word in segment["words"]:
                token = ASRToken(word["start"], word["end"], word["text"])
                tokens.append(token)
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [segment["end"] for segment in res["segments"]]

    def use_vad(self):
        self.transcribe_kargs["vad"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"
    def detect_language(self, audio_file_path):
        """
        Detect the language of the audio using Whisper's language detection.

        Args:
            audio_file_path (str): Path to the audio file.

        Returns:
            tuple: (detected_language, confidence, probabilities)
                - detected_language (str): The detected language code
                - confidence (float): Confidence score for the detected language
                - probabilities (dict): Dictionary of language probabilities
        """
        import whisper

        try:
            # Load the audio and pad/trim it to the 30-second window Whisper expects
            audio = whisper.load_audio(audio_file_path)
            audio = whisper.pad_or_trim(audio)
            # Create a mel spectrogram with the number of mel bins the loaded model expects
            # (80 for most sizes, 128 for large-v3)
            mel = whisper.log_mel_spectrogram(audio, n_mels=self.model.dims.n_mels).to(self.model.device)
            # Detect language
            _, probs = self.model.detect_language(mel)
            detected_lang = max(probs, key=probs.get)
            confidence = probs[detected_lang]
            return detected_lang, confidence, probs
        except Exception as e:
            logger.error(f"Error in language detection: {e}")
            raise
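
# Illustrative usage sketch (commented out, not part of the original module).
# The file name "sample.wav" and the "base" model size are placeholders;
# whisper.load_audio is used because this backend already depends on openai-whisper
# and it returns the 16 kHz float32 mono buffer whisper_timestamped expects.
#
# import whisper
#
# asr = WhisperTimestampedASR(lan="en", modelsize="base")
# audio = whisper.load_audio("sample.wav")
# result = asr.transcribe(audio)
# tokens = asr.ts_words(result)
# print(asr.sep.join(token.text for token in tokens))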
class FasterWhisperASR(ASRBase):
    """Uses faster-whisper as the backend."""

    sep = ""

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        print("Loading faster-whisper model")
        from faster_whisper import WhisperModel

        if model_dir is not None:
            logger.debug(
                f"Loading whisper model from model_dir {model_dir}. "
                f"modelsize and cache_dir parameters are not used."
            )
            model_size_or_path = model_dir
        elif modelsize is not None:
            model_size_or_path = modelsize
        else:
            raise ValueError("Either modelsize or model_dir must be set")

        device = "cuda" if torch and torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "float32"
        print(f"Loading whisper model {model_size_or_path} on {device} with compute type {compute_type}")
        model = WhisperModel(
            model_size_or_path,
            device=device,
            compute_type=compute_type,
            download_root=cache_dir,
        )
        return model
    def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
        segments, info = self.model.transcribe(
            audio,
            language=self.original_language,  # None when lan == "auto", so detection still applies
            initial_prompt=init_prompt,
            beam_size=5,
            word_timestamps=True,
            condition_on_previous_text=True,
            **self.transcribe_kargs,
        )
        return list(segments)
    def ts_words(self, segments) -> List[ASRToken]:
        tokens = []
        for segment in segments:
            if segment.no_speech_prob > 0.9:
                continue
            for word in segment.words:
                token = ASRToken(word.start, word.end, word.word, probability=word.probability)
                tokens.append(token)
        return tokens

    def segments_end_ts(self, segments) -> List[float]:
        return [segment.end for segment in segments]

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"
    def detect_language(self, audio_file_path):
        """
        Detect the language of the audio using faster-whisper's language detection.

        Args:
            audio_file_path: Path to the audio file.

        Returns:
            tuple: (detected_language, confidence, probabilities)
                - detected_language (str): The detected language code
                - confidence (float): Confidence score for the detected language
                - probabilities (dict): Dictionary of language probabilities
        """
        from faster_whisper.audio import decode_audio

        try:
            audio = decode_audio(audio_file_path, sampling_rate=self.model.feature_extractor.sampling_rate)
            # Calculate the number of 30-second segments available for detection
            audio_duration = len(audio) / self.model.feature_extractor.sampling_rate
            segments_num = max(1, int(audio_duration / 30))  # at least 1 segment
            logger.info(
                f"Audio duration: {audio_duration:.2f}s, using {segments_num} segments for language detection"
            )
            # Use faster-whisper's detect_language method
            language, language_probability, all_language_probs = self.model.detect_language(
                audio=audio,
                vad_filter=False,  # disable VAD for language detection
                language_detection_segments=segments_num,  # use all available segments
                language_detection_threshold=0.5,  # default threshold
            )
            # Convert the list of (language, probability) tuples to a dict for a consistent return format
            probs = {lang: prob for lang, prob in all_language_probs}
            return language, language_probability, probs
        except Exception as e:
            logger.error(f"Error in language detection: {e}")
            raise
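
# Illustrative usage sketch (commented out, not part of the original module).
# The file name "sample.wav" and the "large-v3" model size are placeholders; any
# loader that yields 16 kHz float32 mono audio can replace decode_audio here.
#
# from faster_whisper.audio import decode_audio
#
# asr = FasterWhisperASR(lan="auto", modelsize="large-v3")
# lang, confidence, _ = asr.detect_language("sample.wav")
# print(f"detected {lang} ({confidence:.2f})")
#
# audio = decode_audio("sample.wav", sampling_rate=16000)
# segments = asr.transcribe(audio)
# tokens = asr.ts_words(segments)
# print(asr.sep.join(token.text for token in tokens))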
class MLXWhisper(ASRBase):
    """
    Uses MLX Whisper optimized for Apple Silicon.
    """

    sep = ""

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        print("Loading mlx whisper model")
        from mlx_whisper.transcribe import ModelHolder, transcribe
        import mlx.core as mx

        if model_dir is not None:
            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
            model_size_or_path = model_dir
        elif modelsize is not None:
            model_size_or_path = self.translate_model_name(modelsize)
            logger.debug(
                f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used."
            )
        else:
            raise ValueError("Either modelsize or model_dir must be set")

        self.model_size_or_path = model_size_or_path
        dtype = mx.float16
        ModelHolder.get_model(model_size_or_path, dtype)
        return transcribe

    def translate_model_name(self, model_name):
        model_mapping = {
            "tiny.en": "mlx-community/whisper-tiny.en-mlx",
            "tiny": "mlx-community/whisper-tiny-mlx",
            "base.en": "mlx-community/whisper-base.en-mlx",
            "base": "mlx-community/whisper-base-mlx",
            "small.en": "mlx-community/whisper-small.en-mlx",
            "small": "mlx-community/whisper-small-mlx",
            "medium.en": "mlx-community/whisper-medium.en-mlx",
            "medium": "mlx-community/whisper-medium-mlx",
            "large-v1": "mlx-community/whisper-large-v1-mlx",
            "large-v2": "mlx-community/whisper-large-v2-mlx",
            "large-v3": "mlx-community/whisper-large-v3-mlx",
            "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
            "large": "mlx-community/whisper-large-mlx",
        }
        mlx_model_path = model_mapping.get(model_name)
        if mlx_model_path:
            return mlx_model_path
        else:
            raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
def transcribe(self, audio, init_prompt=""): | |
if self.transcribe_kargs: | |
logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.") | |
segments = self.model( | |
audio, | |
language=self.original_language, | |
initial_prompt=init_prompt, | |
word_timestamps=True, | |
condition_on_previous_text=True, | |
path_or_hf_repo=self.model_size_or_path, | |
) | |
return segments.get("segments", []) | |
def ts_words(self, segments) -> List[ASRToken]: | |
tokens = [] | |
for segment in segments: | |
if segment.get("no_speech_prob", 0) > 0.9: | |
continue | |
for word in segment.get("words", []): | |
token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"]) | |
tokens.append(token) | |
return tokens | |
def segments_end_ts(self, res) -> List[float]: | |
return [s["end"] for s in res] | |
def use_vad(self): | |
self.transcribe_kargs["vad_filter"] = True | |
def set_translate_task(self): | |
self.transcribe_kargs["task"] = "translate" | |
def detect_language(self, audio): | |
raise NotImplementedError("MLX Whisper does not support language detection.") | |
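
# Illustrative usage sketch (commented out, not part of the original module).
# Runs only on Apple Silicon with mlx-whisper installed; the file name "sample.wav"
# and the "base" model size are placeholders, and soundfile is used here simply
# because it is already imported above (any 16 kHz float32 mono loader works).
#
# asr = MLXWhisper(lan="en", modelsize="base")  # resolves to mlx-community/whisper-base-mlx
# audio, _ = sf.read("sample.wav", dtype="float32")
# segments = asr.transcribe(audio)
# tokens = asr.ts_words(segments)
# print(asr.sep.join(token.text for token in tokens))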
class OpenaiApiASR(ASRBase):
    """Uses OpenAI's Whisper API for transcription."""

    def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
        print("Loading openai api model")
        self.logfile = logfile
        self.modelname = "whisper-1"
        self.original_language = None if lan == "auto" else lan
        self.response_format = "verbose_json"
        self.temperature = temperature
        self.load_model()
        self.use_vad_opt = False
        self.task = "transcribe"

    def load_model(self, *args, **kwargs):
        from openai import OpenAI

        self.client = OpenAI()
        self.transcribed_seconds = 0

    def ts_words(self, segments) -> List[ASRToken]:
        """
        Convert OpenAI API response words into ASRToken objects while
        optionally skipping words that fall into no-speech segments.
        """
        no_speech_segments = []
        if self.use_vad_opt:
            for segment in segments.segments:
                if segment.no_speech_prob > 0.8:
                    no_speech_segments.append((segment.start, segment.end))

        tokens = []
        for word in segments.words:
            start = word.start
            end = word.end
            if any(s[0] <= start <= s[1] for s in no_speech_segments):
                continue
            tokens.append(ASRToken(start, end, word.word))
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [s.end for s in res.words]
    def transcribe(self, audio_data, prompt=None, *args, **kwargs):
        # Write the raw audio into an in-memory WAV file for the API upload
        buffer = io.BytesIO()
        buffer.name = "temp.wav"
        sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
        buffer.seek(0)

        self.transcribed_seconds += math.ceil(len(audio_data) / 16000)

        params = {
            "model": self.modelname,
            "file": buffer,
            "response_format": self.response_format,
            "temperature": self.temperature,
            "timestamp_granularities": ["word", "segment"],
        }
        if self.task != "translate" and self.original_language:
            params["language"] = self.original_language
        if prompt:
            params["prompt"] = prompt

        proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
        transcript = proc.create(**params)
        logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
        return transcript

    def use_vad(self):
        self.use_vad_opt = True

    def set_translate_task(self):
        self.task = "translate"

    def detect_language(self, audio):
        raise NotImplementedError("OpenaiApiASR does not support language detection.")
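
# Minimal command-line self-test sketch (not part of the original module). It assumes
# faster-whisper is installed; the default path "sample.wav" and the "base" model size
# are placeholders.
if __name__ == "__main__":
    from faster_whisper.audio import decode_audio

    audio_path = sys.argv[1] if len(sys.argv) > 1 else "sample.wav"
    asr = FasterWhisperASR(lan="auto", modelsize="base")
    audio = decode_audio(audio_path, sampling_rate=16000)
    tokens = asr.ts_words(asr.transcribe(audio))
    print(asr.sep.join(token.text for token in tokens))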