AnyaSchen commited on
Commit
13db51f
·
1 Parent(s): 070d9af

feat: try to add language detector

Browse files
Files changed (2) hide show
  1. language_detector.py +84 -0
  2. main.py +36 -2
language_detector.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import whisper
2
+ import numpy as np
3
+ import logging
4
+ import io
5
+ import librosa
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
class LanguageDetector:
    """Detect the spoken language of audio using an OpenAI Whisper model.

    Exposes two entry points — one for an audio file on disk and one for
    in-memory audio bytes — both returning ``(language_code, confidence)``.
    """

    def __init__(self, model_name="tiny"):
        """
        Initialize the language detector with a Whisper model.

        Args:
            model_name (str): Name of the Whisper model to use. Default is
                "tiny", which is sufficient for language detection.
        """
        self.model = whisper.load_model(model_name)
        logger.info(f"Loaded Whisper model {model_name} for language detection")

    def _detect(self, audio):
        """Run Whisper language detection on a mono float32 waveform at 16 kHz.

        Args:
            audio (np.ndarray): 1-D float32 samples normalized to [-1, 1].

        Returns:
            tuple: (detected language code such as "en", confidence float).
        """
        # Whisper's spectrogram input is a fixed 30-second window.
        audio = whisper.pad_or_trim(audio)

        # Make log-Mel spectrogram on the model's device.
        mel = whisper.log_mel_spectrogram(audio).to(self.model.device)

        # detect_language returns (tokens, {lang: probability}).
        _, probs = self.model.detect_language(mel)
        detected_lang = max(probs, key=probs.get)
        return detected_lang, probs[detected_lang]

    def detect_language_from_file(self, audio_file_path):
        """
        Detect language from an audio file.

        Args:
            audio_file_path (str): Path to the audio file.

        Returns:
            str: Detected language code (e.g., "en", "fr", etc.)
            float: Confidence score

        Raises:
            Exception: Re-raised after logging if loading or detection fails.
        """
        try:
            # whisper.load_audio already yields float32 mono at 16 kHz.
            audio = whisper.load_audio(audio_file_path)
            return self._detect(audio)
        except Exception as e:
            logger.error(f"Error in language detection: {e}")
            raise

    def detect_language_from_bytes(self, audio_bytes):
        """
        Detect language from audio bytes.

        Args:
            audio_bytes (bytes): Audio data in bytes.

        Returns:
            str: Detected language code (e.g., "en", "fr", etc.)
            float: Confidence score

        Raises:
            Exception: Re-raised after logging if decoding or detection fails.
        """
        try:
            # Decode bytes to a mono float32 waveform resampled to 16 kHz,
            # matching what whisper.load_audio produces for files.
            audio, _sr = librosa.load(io.BytesIO(audio_bytes), sr=16000)

            # BUG FIX: the previous code converted to int16 via
            # (audio * 32768).astype(np.int16), but Whisper's
            # log_mel_spectrogram expects float32 samples in [-1, 1];
            # int16-scaled values produce a garbage spectrogram and
            # unreliable detection. Keep the normalized float32 waveform.
            audio = audio.astype(np.float32)

            return self._detect(audio)
        except Exception as e:
            logger.error(f"Error in language detection: {e}")
            raise
main.py CHANGED
@@ -1,5 +1,5 @@
1
  from contextlib import asynccontextmanager
2
- from fastapi import FastAPI, WebSocket, WebSocketDisconnect
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import JSONResponse
5
  from fastapi.staticfiles import StaticFiles
@@ -10,9 +10,13 @@ import os
10
  import traceback
11
  import argparse
12
  import uvicorn
 
 
 
13
 
14
  from core import WhisperLiveKit
15
  from audio_processor import AudioProcessor
 
16
 
17
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
18
  logging.getLogger().setLevel(logging.WARNING)
@@ -20,11 +24,13 @@ logger = logging.getLogger(__name__)
20
  logger.setLevel(logging.DEBUG)
21
 
22
  kit = None
 
23
 
24
  @asynccontextmanager
25
  async def lifespan(app: FastAPI):
26
- global kit
27
  kit = WhisperLiveKit()
 
28
  yield
29
 
30
  app = FastAPI(lifespan=lifespan)
@@ -50,6 +56,34 @@ async def read_root():
50
  async def health_check():
51
  return JSONResponse({"status": "healthy"})
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  async def handle_websocket_results(websocket, results_generator):
54
  """Consumes results from the audio processor and sends them via WebSocket."""
55
  try:
 
1
  from contextlib import asynccontextmanager
2
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, UploadFile, File
3
  from fastapi.middleware.cors import CORSMiddleware
4
  from fastapi.responses import JSONResponse
5
  from fastapi.staticfiles import StaticFiles
 
10
  import traceback
11
  import argparse
12
  import uvicorn
13
+ import numpy as np
14
+ import librosa
15
+ import io
16
 
17
  from core import WhisperLiveKit
18
  from audio_processor import AudioProcessor
19
+ from language_detector import LanguageDetector
20
 
21
  logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
22
  logging.getLogger().setLevel(logging.WARNING)
 
24
  logger.setLevel(logging.DEBUG)
25
 
26
  kit = None
27
+ language_detector = None
28
 
29
  @asynccontextmanager
30
  async def lifespan(app: FastAPI):
31
+ global kit, language_detector
32
  kit = WhisperLiveKit()
33
+ language_detector = LanguageDetector(model_name="tiny")
34
  yield
35
 
36
  app = FastAPI(lifespan=lifespan)
 
56
  async def health_check():
57
  return JSONResponse({"status": "healthy"})
58
 
59
@app.post("/detect-language")
async def detect_language(file: UploadFile = File(...)):
    """Detect the spoken language of an uploaded audio file.

    Reads the uploaded bytes and delegates to the module-level
    ``language_detector`` (created in the app lifespan hook). Responds with
    JSON ``{"language": <code>, "confidence": <float>}`` on success, or an
    ``{"error": ...}`` payload with HTTP 500 on failure.
    """
    try:
        # Pull the whole upload into memory for the detector.
        contents = await file.read()

        # Guard clause: the detector is only set during application startup.
        if not language_detector:
            return JSONResponse(
                {"error": "Language detector not initialized"},
                status_code=500
            )

        detected_lang, confidence = language_detector.detect_language_from_bytes(contents)
        return JSONResponse({
            "language": detected_lang,
            "confidence": float(confidence)
        })

    except Exception as e:
        logger.error(f"Error in language detection: {e}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return JSONResponse(
            {"error": str(e)},
            status_code=500
        )
86
+
87
  async def handle_websocket_results(websocket, results_generator):
88
  """Consumes results from the audio processor and sends them via WebSocket."""
89
  try: