whisper-websocket / language_detector.py
AnyaSchen's picture
feat: try to add language detector 3
b8a4e79
raw
history blame
2.77 kB
import whisper as whp
import numpy as np
import logging
import io
import librosa
logger = logging.getLogger(__name__)
class LanguageDetector:
    """Detect the spoken language of an audio clip with a Whisper model.

    Both public entry points decode their input to a mono waveform and then
    share one pipeline: pad/trim to Whisper's fixed window, build a log-Mel
    spectrogram, and ask the model for per-language probabilities.
    """

    # Sample rate (Hz) requested from librosa when decoding raw bytes.
    SAMPLE_RATE = 16000

    def __init__(self, model_name="tiny"):
        """
        Initialize the language detector with a Whisper model.

        Args:
            model_name (str): Name of the Whisper model to use. Default is
                "tiny" which is sufficient for language detection.
        """
        self.model = whp.load_model(model_name)
        logger.info(f"Loaded Whisper model {model_name} for language detection")

    def detect_language_from_file(self, audio_file_path):
        """
        Detect language from an audio file.

        Args:
            audio_file_path (str): Path to the audio file

        Returns:
            tuple: ``(language_code, confidence)`` where ``language_code`` is
            a string such as "en" or "fr" and ``confidence`` is the
            probability the model assigned to it.

        Raises:
            Exception: Any error raised while loading or analysing the audio
                is logged and re-raised unchanged.
        """
        try:
            audio = whp.load_audio(audio_file_path)
            return self._detect_from_waveform(audio)
        except Exception as e:
            logger.error(f"Error in language detection: {e}")
            raise

    def detect_language_from_bytes(self, audio_bytes):
        """
        Detect language from audio bytes.

        Args:
            audio_bytes (bytes): Encoded audio data (any container/codec that
                librosa can decode from a file-like object).

        Returns:
            tuple: ``(language_code, confidence)`` where ``language_code`` is
            a string such as "en" or "fr" and ``confidence`` is the
            probability the model assigned to it.

        Raises:
            Exception: Any error raised while decoding or analysing the audio
                is logged and re-raised unchanged.
        """
        try:
            # librosa decodes to a floating-point waveform, which is the same
            # form the file path hands to Whisper. The previous int16 rescale
            # fed integer samples to the spectrogram step and could wrap
            # full-scale samples (1.0 * 32768 exceeds the int16 maximum).
            audio, _ = librosa.load(io.BytesIO(audio_bytes), sr=self.SAMPLE_RATE)
            return self._detect_from_waveform(audio)
        except Exception as e:
            logger.error(f"Error in language detection: {e}")
            raise

    def _detect_from_waveform(self, audio):
        """Run the shared pad -> log-Mel -> detect pipeline on a waveform.

        Args:
            audio: 1-D array-like of audio samples.

        Returns:
            tuple: ``(language_code, confidence)``.
        """
        audio = whp.pad_or_trim(self._as_float32_waveform(audio))
        # Use the model's own Mel-bin count so checkpoints that do not use
        # the default number of bins still receive a correctly shaped input.
        mel = whp.log_mel_spectrogram(audio, n_mels=self.model.dims.n_mels)
        mel = mel.to(self.model.device)
        _, probs = self.model.detect_language(mel)
        return self._most_likely_language(probs)

    @staticmethod
    def _as_float32_waveform(audio):
        """Return ``audio`` as a contiguous float32 NumPy array.

        Only the dtype/layout is normalised; sample values are not rescaled.
        """
        return np.ascontiguousarray(audio, dtype=np.float32)

    @staticmethod
    def _most_likely_language(probs):
        """Return ``(language, probability)`` for the highest-scoring entry.

        Args:
            probs (dict): Mapping of language code to probability.
        """
        detected_lang = max(probs, key=probs.get)
        return detected_lang, probs[detected_lang]