feat: try to add language detector
- language_detector.py +84 -0
- main.py +36 -2
language_detector.py
ADDED
@@ -0,0 +1,84 @@
+import whisper
+import numpy as np
+import logging
+import io
+import librosa
+
+logger = logging.getLogger(__name__)
+
+class LanguageDetector:
+    def __init__(self, model_name="tiny"):
+        """
+        Initialize the language detector with a Whisper model.
+
+        Args:
+            model_name (str): Name of the Whisper model to use. Default is "tiny" which is sufficient for language detection.
+        """
+        self.model = whisper.load_model(model_name)
+        logger.info(f"Loaded Whisper model {model_name} for language detection")
+
+    def detect_language_from_file(self, audio_file_path):
+        """
+        Detect language from an audio file.
+
+        Args:
+            audio_file_path (str): Path to the audio file
+
+        Returns:
+            str: Detected language code (e.g., "en", "fr", etc.)
+            float: Confidence score
+        """
+        try:
+            # Load and preprocess audio
+            audio = whisper.load_audio(audio_file_path)
+            audio = whisper.pad_or_trim(audio)
+
+            # Make log-Mel spectrogram
+            mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
+
+            # Detect language
+            _, probs = self.model.detect_language(mel)
+            detected_lang = max(probs, key=probs.get)
+            confidence = probs[detected_lang]
+
+            return detected_lang, confidence
+
+        except Exception as e:
+            logger.error(f"Error in language detection: {e}")
+            raise
+
+    def detect_language_from_bytes(self, audio_bytes):
+        """
+        Detect language from audio bytes.
+
+        Args:
+            audio_bytes (bytes): Audio data in bytes
+
+        Returns:
+            str: Detected language code (e.g., "en", "fr", etc.)
+            float: Confidence score
+        """
+        try:
+            # Convert bytes to numpy array using librosa
+            audio_data = io.BytesIO(audio_bytes)
+            audio, sr = librosa.load(audio_data, sr=16000)
+
+            # Ensure float32 samples in [-1, 1]; this is the format Whisper expects
+            audio = audio.astype(np.float32)
+
+            # Load and preprocess audio
+            audio = whisper.pad_or_trim(audio)
+
+            # Make log-Mel spectrogram
+            mel = whisper.log_mel_spectrogram(audio).to(self.model.device)
+
+            # Detect language
+            _, probs = self.model.detect_language(mel)
+            detected_lang = max(probs, key=probs.get)
+            confidence = probs[detected_lang]
+
+            return detected_lang, confidence
+
+        except Exception as e:
+            logger.error(f"Error in language detection: {e}")
+            raise
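For a quick local check, the new class can also be used on its own; a minimal sketch, not part of this commit (the model name and audio path are placeholders):

# Illustrative standalone use of the detector added above.
from language_detector import LanguageDetector

detector = LanguageDetector(model_name="tiny")
lang, confidence = detector.detect_language_from_file("sample.wav")  # placeholder path
print(f"Detected language: {lang} (confidence: {confidence:.2f})")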
main.py
CHANGED
@@ -1,5 +1,5 @@
 from contextlib import asynccontextmanager
-from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 from fastapi.staticfiles import StaticFiles
@@ -10,9 +10,13 @@ import os
 import traceback
 import argparse
 import uvicorn
+import numpy as np
+import librosa
+import io
 
 from core import WhisperLiveKit
 from audio_processor import AudioProcessor
+from language_detector import LanguageDetector
 
 logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 logging.getLogger().setLevel(logging.WARNING)
@@ -20,11 +24,13 @@ logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
 
 kit = None
+language_detector = None
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    global kit
+    global kit, language_detector
     kit = WhisperLiveKit()
+    language_detector = LanguageDetector(model_name="tiny")
     yield
 
 app = FastAPI(lifespan=lifespan)
@@ -50,6 +56,34 @@ async def read_root():
 async def health_check():
     return JSONResponse({"status": "healthy"})
 
+@app.post("/detect-language")
+async def detect_language(file: UploadFile = File(...)):
+    try:
+        # Read the uploaded audio file
+        contents = await file.read()
+
+        # Use the language detector
+        if language_detector:
+            detected_lang, confidence = language_detector.detect_language_from_bytes(contents)
+
+            return JSONResponse({
+                "language": detected_lang,
+                "confidence": float(confidence)
+            })
+        else:
+            return JSONResponse(
+                {"error": "Language detector not initialized"},
+                status_code=500
+            )
+
+    except Exception as e:
+        logger.error(f"Error in language detection: {e}")
+        logger.error(f"Traceback: {traceback.format_exc()}")
+        return JSONResponse(
+            {"error": str(e)},
+            status_code=500
+        )
+
 async def handle_websocket_results(websocket, results_generator):
     """Consumes results from the audio processor and sends them via WebSocket."""
     try:
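With these changes, the new endpoint accepts a multipart file upload and returns the detected language and confidence. A hedged client sketch using the requests library follows; the base URL, port, and audio file name are assumptions, not part of this commit:

# Illustrative client call for the new /detect-language endpoint.
import requests

with open("sample.wav", "rb") as f:  # placeholder audio file
    resp = requests.post(
        "http://localhost:8000/detect-language",  # assumed host/port
        files={"file": ("sample.wav", f, "audio/wav")},
    )

print(resp.json())  # e.g. {"language": "en", "confidence": 0.97}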